24
24
from pyccel .ast .builtins import PythonTuple , DtypePrecisionToCastFunction
25
25
from pyccel .ast .builtins import PythonBool , PythonList , PythonSet
26
26
27
+ from pyccel .ast .builtin_methods .set_methods import SetUnion
28
+
27
29
from pyccel .ast .core import FunctionDef , FunctionDefArgument , FunctionDefResult
28
30
from pyccel .ast .core import SeparatorComment , Comment
29
31
from pyccel .ast .core import ConstructorCall
@@ -262,7 +264,7 @@ def __init__(self, filename, prefix_module = None):
262
264
self ._constantImports = {}
263
265
self ._current_class = None
264
266
265
- self ._additional_code = None
267
+ self ._additional_code = ''
266
268
267
269
self .prefix_module = prefix_module
268
270
@@ -425,8 +427,6 @@ def _handle_inline_func_call(self, expr, assign_lhs = None):
425
427
# Everything before the return node needs handling before the line
426
428
# which calls the inline function is executed
427
429
code = self ._print (body )
428
- if (not self ._additional_code ):
429
- self ._additional_code = ''
430
430
self ._additional_code += code
431
431
432
432
# Collect statements from results to return object
@@ -888,7 +888,7 @@ def _print_Import(self, expr):
888
888
889
889
source = expr .source
890
890
if isinstance (source , DottedName ):
891
- source = source .name [- 1 ]
891
+ source = source .name [- 1 ]. python_value
892
892
elif isinstance (source , LiteralString ):
893
893
source = source .python_value
894
894
else :
@@ -1276,14 +1276,12 @@ def _print_Constant(self, expr):
1276
1276
def _print_DottedVariable (self , expr ):
1277
1277
if isinstance (expr .lhs , FunctionCall ):
1278
1278
base = expr .lhs .funcdef .results [0 ].var
1279
- if (not self ._additional_code ):
1280
- self ._additional_code = ''
1281
1279
var_name = self .scope .get_new_name ()
1282
1280
var = base .clone (var_name )
1283
1281
1284
1282
self .scope .insert_variable (var )
1285
1283
1286
- self ._additional_code = self . _additional_code + self ._print (Assign (var ,expr .lhs )) + '\n '
1284
+ self ._additional_code += self ._print (Assign (var ,expr .lhs )) + '\n '
1287
1285
return self ._print (var ) + '%' + self ._print (expr .name )
1288
1286
else :
1289
1287
return self ._print (expr .lhs ) + '%' + self ._print (expr .name )
@@ -1337,6 +1335,31 @@ def _print_SetCopy(self, expr):
1337
1335
type_name = self ._print (expr .class_type )
1338
1336
return f'{ type_name } ({ var_code } )'
1339
1337
1338
+ def _print_SetUnion (self , expr ):
1339
+ assign_base = expr .get_direct_user_nodes (lambda n : isinstance (n , Assign ))
1340
+ var = expr .set_variable
1341
+ if not assign_base :
1342
+ result = self ._print (self .scope .get_temporary_variable (var ))
1343
+ else :
1344
+ result = self ._print (assign_base [0 ].lhs )
1345
+ expr_type = var .class_type
1346
+ var_code = self ._print (expr .set_variable )
1347
+ type_name = self ._print (expr_type )
1348
+ self .add_import (self ._build_gFTL_extension_module (expr_type ))
1349
+ args_insert = []
1350
+ for arg in expr .args :
1351
+ a = self ._print (arg )
1352
+ if arg .class_type == expr_type :
1353
+ args_insert .append (f'call { result } % merge({ a } )\n ' )
1354
+ else :
1355
+ errors .report (PYCCEL_RESTRICTION_TODO , severity = 'error' , symbol = expr )
1356
+ code = f'{ result } = { type_name } ({ var_code } )\n ' + '' .join (args_insert )
1357
+ if assign_base :
1358
+ return code
1359
+ else :
1360
+ self ._additional_code += code
1361
+ return result
1362
+
1340
1363
#========================== Numpy Elements ===============================#
1341
1364
1342
1365
def _print_NumpySum (self , expr ):
@@ -1675,12 +1698,10 @@ def _print_NumpyRand(self, expr):
1675
1698
errors .report (FORTRAN_ALLOCATABLE_IN_EXPRESSION ,
1676
1699
symbol = expr , severity = 'fatal' )
1677
1700
1678
- if (not self ._additional_code ):
1679
- self ._additional_code = ''
1680
1701
var = self .scope .get_temporary_variable (expr .dtype , memory_handling = 'stack' ,
1681
1702
shape = expr .shape )
1682
1703
1683
- self ._additional_code = self . _additional_code + self ._print (Assign (var ,expr )) + '\n '
1704
+ self ._additional_code += self ._print (Assign (var ,expr )) + '\n '
1684
1705
return self ._print (var )
1685
1706
1686
1707
def _print_NumpyRandint (self , expr ):
@@ -1978,9 +1999,9 @@ def _print_CodeBlock(self, expr):
1978
1999
body_stmts = []
1979
2000
for b in body_exprs :
1980
2001
line = self ._print (b )
1981
- if ( self ._additional_code ) :
2002
+ if self ._additional_code :
1982
2003
body_stmts .append (self ._additional_code )
1983
- self ._additional_code = None
2004
+ self ._additional_code = ''
1984
2005
body_stmts .append (line )
1985
2006
return '' .join (self ._print (b ) for b in body_stmts )
1986
2007
@@ -2000,7 +2021,7 @@ def _print_NumpyReal(self, expr):
2000
2021
def _print_Assign (self , expr ):
2001
2022
rhs = expr .rhs
2002
2023
2003
- if isinstance (rhs , FunctionCall ):
2024
+ if isinstance (rhs , ( FunctionCall , SetUnion ) ):
2004
2025
return self ._print (rhs )
2005
2026
2006
2027
lhs_code = self ._print (expr .lhs )
@@ -3397,17 +3418,13 @@ def _print_FunctionCall(self, expr):
3397
3418
args = args [1 :]
3398
3419
if isinstance (class_variable , FunctionCall ):
3399
3420
base = class_variable .funcdef .results [0 ].var
3400
- if (not self ._additional_code ):
3401
- self ._additional_code = ''
3402
3421
var = self .scope .get_temporary_variable (base )
3403
3422
3404
- self ._additional_code = self . _additional_code + self ._print (Assign (var , class_variable )) + '\n '
3423
+ self ._additional_code += self ._print (Assign (var , class_variable )) + '\n '
3405
3424
f_name = f'{ self ._print (var )} % { f_name } '
3406
3425
else :
3407
3426
f_name = f'{ self ._print (class_variable )} % { f_name } '
3408
3427
3409
- if (not self ._additional_code ):
3410
- self ._additional_code = ''
3411
3428
if parent_assign :
3412
3429
lhs = parent_assign [0 ].lhs
3413
3430
if len (func_results ) == 1 :
0 commit comments