Skip to content

Commit 2728d40

Browse files
EmilyBourneyguclu
andauthored
Use PyccelAstNode.pyccel_staging in SemanticParser (pyccel#1839)
Use the `pyccel_staging` property of `PyccelAstNode` objects to facilitate the creation of partially syntactic object (e.g. when unravelling an assignment where the `rhs` has been visited and the `lhs` is an inhomogeneous tuple). This is done by exiting `SemanticParser._visit` immediately if the object passed to the function is already a semantic object. As the property is now always used it is obvious when it has not been correctly set. These cases are therefore fixed. --------- Co-authored-by: Yaman Güçlü <[email protected]>
1 parent 046e93f commit 2728d40

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ All notable changes to this project will be documented in this file.
5656
- \[INTERNALS\] Remove the `order` argument from the `pyccel.ast.core.Allocate` constructor.
5757
- \[INTERNALS\] Remove `rank` and `order` arguments from `pyccel.ast.variable.Variable` constructor.
5858
- \[INTERNALS\] Ensure `SemanticParser.infer_type` returns all documented information.
59+
- \[INTERNALS\] Enforce correct value for `pyccel_staging` property of `PyccelAstNode`.
60+
- \[INTERNALS\] Allow visiting objects containing both syntactic and semantic elements in `SemanticParser`.
5961

6062
### Deprecated
6163

pyccel/parser/semantic.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -775,9 +775,11 @@ def is_int(a):
775775
else:
776776
return self._visit(var[indices[0]][indices[1:]])
777777
else:
778+
pyccel_stage.set_stage('syntactic')
778779
tmp_var = PyccelSymbol(self.scope.get_new_name())
779780
assign = Assign(tmp_var, var)
780781
assign.set_current_ast(expr.python_ast)
782+
pyccel_stage.set_stage('semantic')
781783
self._additional_exprs[-1].append(self._visit(assign))
782784
var = self._visit(tmp_var)
783785

@@ -1074,9 +1076,11 @@ def _handle_function(self, expr, func, args, is_method = False):
10741076

10751077
func_results = func.results if isinstance(func, FunctionDef) else func.functions[0].results
10761078
if not parent_assign and len(func_results) == 1 and func_results[0].var.rank > 0:
1079+
pyccel_stage.set_stage('syntactic')
10771080
tmp_var = PyccelSymbol(self.scope.get_new_name())
10781081
assign = Assign(tmp_var, expr)
10791082
assign.set_current_ast(expr.python_ast)
1083+
pyccel_stage.set_stage('semantic')
10801084
self._additional_exprs[-1].append(self._visit(assign))
10811085
return self._visit(tmp_var)
10821086

@@ -1714,7 +1718,10 @@ def _assign_GeneratorComprehension(self, lhs_name, expr):
17141718
# Use _visit_Assign to create the requested iterator with the correct type
17151719
# The result of this operation is not stored, it is just used to declare
17161720
# iterator with the correct dtype to allow correct dtype deductions later
1717-
self._visit(Assign(iterator, iterator_rhs, python_ast=expr.python_ast))
1721+
pyccel_stage.set_stage('syntactic')
1722+
syntactic_assign = Assign(iterator, iterator_rhs, python_ast=expr.python_ast)
1723+
pyccel_stage.set_stage('semantic')
1724+
self._visit(syntactic_assign)
17181725

17191726
loop_elem = loop.body.body[0]
17201727

@@ -1724,9 +1731,13 @@ def _assign_GeneratorComprehension(self, lhs_name, expr):
17241731
gens = set(loop_elem.get_attribute_nodes(GeneratorComprehension))
17251732
if len(gens)==1:
17261733
gen = gens.pop()
1734+
pyccel_stage.set_stage('syntactic')
17271735
assert isinstance(gen.lhs, PyccelSymbol) and gen.lhs.is_temp
17281736
gen_lhs = self.scope.get_new_name() if gen.lhs.is_temp else gen.lhs
1729-
assign = self._visit(Assign(gen_lhs, gen, python_ast=gen.python_ast))
1737+
syntactic_assign = Assign(gen_lhs, gen, python_ast=gen.python_ast)
1738+
pyccel_stage.set_stage('semantic')
1739+
assign = self._visit(syntactic_assign)
1740+
17301741
new_expr.append(assign)
17311742
loop.substitute(gen, assign.lhs)
17321743
loop_elem = loop.body.body[0]
@@ -1857,7 +1868,9 @@ def _get_indexed_type(self, base, args, expr):
18571868
if isinstance(base, PyccelFunctionDef) and base.cls_name is TypingFinal:
18581869
syntactic_annotation = args[0]
18591870
if not isinstance(syntactic_annotation, SyntacticTypeAnnotation):
1871+
pyccel_stage.set_stage('syntactic')
18601872
syntactic_annotation = SyntacticTypeAnnotation(dtype=syntactic_annotation)
1873+
pyccel_stage.set_stage('semantic')
18611874
annotation = self._visit(syntactic_annotation)
18621875
for t in annotation.type_list:
18631876
t.is_const = True
@@ -1887,7 +1900,9 @@ def _get_indexed_type(self, base, args, expr):
18871900
if len(args) == 2 and args[1] is LiteralEllipsis():
18881901
syntactic_annotation = args[0]
18891902
if not isinstance(syntactic_annotation, SyntacticTypeAnnotation):
1903+
pyccel_stage.set_stage('syntactic')
18901904
syntactic_annotation = SyntacticTypeAnnotation(dtype=syntactic_annotation)
1905+
pyccel_stage.set_stage('semantic')
18911906
internal_datatypes = self._visit(syntactic_annotation)
18921907
type_annotations = []
18931908
if dtype_cls is PythonTupleFunction:
@@ -1925,14 +1940,16 @@ def _visit(self, expr):
19251940
19261941
Parameters
19271942
----------
1928-
expr : pyccel.ast.basic.PyccelAstNode
1943+
expr : pyccel.ast.basic.PyccelAstNode | PyccelSymbol
19291944
Object to visit of type X.
19301945
19311946
Returns
19321947
-------
19331948
pyccel.ast.basic.PyccelAstNode
19341949
AST object which is the semantic equivalent of expr.
19351950
"""
1951+
if getattr(expr, 'pyccel_staging', 'syntactic') == 'semantic':
1952+
return expr
19361953

19371954
# TODO - add settings to Errors
19381955
# - line and column
@@ -2203,8 +2220,12 @@ def _visit_FunctionCallArgument(self, expr):
22032220
value = self._visit(expr.value)
22042221
a = FunctionCallArgument(value, expr.keyword)
22052222
def generate_and_assign_temp_var():
2223+
pyccel_stage.set_stage('syntactic')
22062224
tmp_var = self.scope.get_new_name()
2207-
assign = self._visit(Assign(tmp_var, expr.value, python_ast = expr.value.python_ast))
2225+
syntactic_assign = Assign(tmp_var, expr.value, python_ast = expr.value.python_ast)
2226+
pyccel_stage.set_stage('semantic')
2227+
2228+
assign = self._visit(syntactic_assign)
22082229
self._additional_exprs[-1].append(assign)
22092230
return FunctionCallArgument(self._visit(tmp_var))
22102231
if isinstance(value, (PyccelArithmeticOperator, PyccelInternalFunction)) and value.rank:
@@ -3663,12 +3684,15 @@ def _visit_FunctionalFor(self, expr):
36633684
def _visit_GeneratorComprehension(self, expr):
36643685
lhs = self.check_for_variable(expr.lhs)
36653686
if lhs is None:
3687+
pyccel_stage.set_stage('syntactic')
36663688
if expr.lhs.is_temp:
36673689
lhs = PyccelSymbol(self.scope.get_new_name(), is_temp=True)
36683690
else:
36693691
lhs = expr.lhs
3692+
syntactic_assign = Assign(lhs, expr, python_ast=expr.python_ast)
3693+
pyccel_stage.set_stage('semantic')
36703694

3671-
creation = self._visit(Assign(lhs, expr, python_ast=expr.python_ast))
3695+
creation = self._visit(syntactic_assign)
36723696
self._additional_exprs[-1].append(creation)
36733697
return self.get_variable(lhs)
36743698
else:
@@ -3774,9 +3798,13 @@ def _visit_Return(self, expr):
37743798
v = o.var
37753799
if not (isinstance(r, PyccelSymbol) and r == (v.name if isinstance(v, Variable) else v)):
37763800
# Create a syntactic object to visit
3801+
pyccel_stage.set_stage('syntactic')
37773802
if isinstance(v, Variable):
37783803
v = PyccelSymbol(v.name)
3779-
a = self._visit(Assign(v, r, python_ast=expr.python_ast))
3804+
syntactic_assign = Assign(v, r, python_ast=expr.python_ast)
3805+
pyccel_stage.set_stage('semantic')
3806+
3807+
a = self._visit(syntactic_assign)
37803808
assigns.append(a)
37813809
if isinstance(a, ConstructorCall):
37823810
a.cls_variable.is_temp = False
@@ -3871,6 +3899,7 @@ def unpack(ann):
38713899
templates = {t: v for t,v in templates.items() if t in used_type_names}
38723900

38733901
# Create new temparary templates for the arguments with a Union data type.
3902+
pyccel_stage.set_stage('syntactic')
38743903
tmp_templates = {}
38753904
new_expr_args = []
38763905
for a in expr.arguments:
@@ -3889,6 +3918,7 @@ def unpack(ann):
38893918
value=a.value, kwonly=a.is_kwonly, annotation=dtype_symb))
38903919
else:
38913920
new_expr_args.append(a)
3921+
pyccel_stage.set_stage('semantic')
38923922

38933923
templates.update(tmp_templates)
38943924
template_combinations = list(product(*[v.type_list for v in templates.values()]))
@@ -4513,10 +4543,15 @@ def _visit_MacroFunction(self, expr):
45134543
for hd in header:
45144544
for i,_ in enumerate(hd.dtypes):
45154545
self.scope.insert_symbol(f'arg_{i}')
4516-
arguments = [FunctionDefArgument(self._visit(AnnotatedPyccelSymbol(f'arg_{i}', annotation = arg))[0]) \
4546+
pyccel_stage.set_stage('syntactic')
4547+
syntactic_args = [AnnotatedPyccelSymbol(f'arg_{i}', annotation = arg) \
45174548
for i, arg in enumerate(hd.dtypes)]
4518-
results = [FunctionDefResult(self._visit(AnnotatedPyccelSymbol(f'out_{i}', annotation = arg))[0]) \
4549+
syntactic_results = [AnnotatedPyccelSymbol(f'out_{i}', annotation = arg) \
45194550
for i, arg in enumerate(hd.results)]
4551+
pyccel_stage.set_stage('semantic')
4552+
4553+
arguments = [FunctionDefArgument(self._visit(a)[0]) for a in syntactic_args]
4554+
results = [FunctionDefResult(self._visit(r)[0]) for r in syntactic_results]
45204555
interfaces.append(FunctionDef(f_name, arguments, results, []))
45214556

45224557
# TODO -> Said: must handle interface
@@ -4600,8 +4635,12 @@ def _visit_NumpyNonZero(self, func_call):
46004635
# expr is a FunctionCall
46014636
arg = func_call_args[0].value
46024637
if not isinstance(arg, Variable):
4638+
pyccel_stage.set_stage('syntactic')
46034639
new_symbol = PyccelSymbol(self.scope.get_new_name())
4604-
creation = self._visit(Assign(new_symbol, arg, python_ast=func_call.python_ast))
4640+
syntactic_assign = Assign(new_symbol, arg, python_ast=func_call.python_ast)
4641+
pyccel_stage.set_stage('semantic')
4642+
4643+
creation = self._visit(syntactic_assign)
46054644
self._additional_exprs[-1].append(creation)
46064645
arg = self._visit(new_symbol)
46074646
return NumpyWhere(arg)

0 commit comments

Comments
 (0)