Skip to content

Commit 5bd938d

Browse files
authored
Add support for set method union (pyccel#2015)
Add support for set method `union`. Fixes pyccel#1753.

**Commit Summary**

- Add a class to represent `SetUnion`
- Remove unused `prefix_code` variable from `CCodePrinter._print_Assign`
- Add C printing for union by adding a function to the extensions
- Add Fortran printing for union by using `copy` and `update`
- Correct error about unknown shape which is irrelevant for objects whose shape can change
- Add tests for `union`
- Fix AST tree for `DottedName`
1 parent 6607b0d commit 5bd938d

File tree

14 files changed

+264
-42
lines changed

14 files changed

+264
-42
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ All notable changes to this project will be documented in this file.
3838
- #1917 : Add C and Fortran support for set method `add()`.
3939
- #1918 : Add support for set method `clear()`.
4040
- #1918 : Add support for set method `copy()`.
41+
- #1753 : Add support for set method `union()`.
4142
- #1936 : Add missing C output for inline decorator example in documentation
4243
- #1937 : Optimise `pyccel.ast.basic.PyccelAstNode.substitute` method.
4344
- #1544 : Add support for `typing.TypeAlias`.

docs/builtin-functions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ Python contains a limited number of builtin functions defined [here](https://doc
111111
| `remove` | Python-only |
112112
| `symmetric_difference` | No |
113113
| `symmetric_difference_update` | No |
114-
| `union` | No |
114+
| **`union`** | Yes |
115115
| `update` | Python-only |
116116

117117
## Dictionary methods

pyccel/ast/builtin_methods/set_methods.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
'SetMethod',
2222
'SetPop',
2323
'SetRemove',
24+
'SetUnion',
2425
'SetUpdate'
2526
)
2627

@@ -227,6 +228,36 @@ class SetUpdate(SetMethod):
227228
"""
228229
__slots__ = ()
229230
name = 'update'
231+
_shape = None
232+
_class_type = VoidType()
230233

231234
def __init__(self, set_obj, iterable) -> None:
232235
super().__init__(set_obj, iterable)
236+
237+
#==============================================================================
238+
class SetUnion(SetMethod):
    """
    Represents a call to the set method .union.

    Represents a call to the set method .union. The node describes a new set
    made of every element found in the calling set or in any of the
    iterables passed as arguments.

    Parameters
    ----------
    set_obj : TypedAstNode
        The set object which the method is called from.
    *others : TypedAstNode
        The iterables which will be combined with this set.
    """
    # NOTE(review): '_other' is declared here but never assigned in this
    # class; confirm whether it is still needed.
    __slots__ = ('_other','_class_type', '_shape')
    name = 'union'

    def __init__(self, set_obj, *others):
        # The result has the same type as the set the method is called on.
        result_type = set_obj.class_type
        self._class_type = result_type
        expected_element_type = result_type.element_type

        # Find the first argument (if any) whose elements do not match.
        mismatch = next((arg for arg in others
                         if expected_element_type != arg.class_type.element_type),
                        None)
        if mismatch is not None:
            raise TypeError(f"Argument of type {mismatch.class_type} cannot be used to build set of type {self._class_type}")

        # One unknown extent (None) per dimension of the result type.
        self._shape = (None,) * result_type.rank
        super().__init__(set_obj, *others)

pyccel/ast/class_defs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
This module contains all types which define a python class which is automatically recognised by pyccel
77
"""
88

9-
from pyccel.ast.builtin_methods.set_methods import SetAdd, SetClear, SetCopy, SetPop, SetRemove, SetDiscard, SetUpdate
9+
from pyccel.ast.builtin_methods.set_methods import (SetAdd, SetClear, SetCopy, SetPop, SetRemove,
10+
SetDiscard, SetUpdate, SetUnion)
1011
from pyccel.ast.builtin_methods.list_methods import (ListAppend, ListInsert, ListPop,
1112
ListClear, ListExtend, ListRemove,
1213
ListCopy, ListSort)
@@ -163,6 +164,7 @@
163164
PyccelFunctionDef('discard', func_class = SetDiscard),
164165
PyccelFunctionDef('pop', func_class = SetPop),
165166
PyccelFunctionDef('remove', func_class = SetRemove),
167+
PyccelFunctionDef('union', func_class = SetUnion),
166168
PyccelFunctionDef('update', func_class = SetUpdate),
167169
])
168170

pyccel/ast/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1970,7 +1970,7 @@ def __init__(self, func, args, current_function=None):
19701970
self._interface = None
19711971
self._funcdef = func
19721972
self._arguments = tuple(args)
1973-
self._func_name = func
1973+
self._func_name = getattr(func, 'name', func)
19741974
super().__init__()
19751975
return
19761976

pyccel/ast/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ class DottedName(PyccelAstNode):
561561
pyccel.stdlib.parallel
562562
"""
563563
__slots__ = ('_name',)
564-
_attribute_nodes = ()
564+
_attribute_nodes = ('_name',)
565565

566566
def __new__(cls, *args):
567567
if len(args) == 1:

pyccel/codegen/printing/ccode.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ def _print_Import(self, expr):
992992
else:
993993
source = expr.source
994994
if isinstance(source, DottedName):
995-
source = source.name[-1]
995+
source = source.name[-1].python_value
996996
else:
997997
source = self._print(source)
998998
if source.startswith('stc/') or source in import_header_guard_prefix:
@@ -2228,22 +2228,21 @@ def _print_AugAssign(self, expr):
22282228
return f'{lhs_code} {op}= {rhs_code};\n'
22292229

22302230
def _print_Assign(self, expr):
2231-
prefix_code = ''
22322231
lhs = expr.lhs
22332232
rhs = expr.rhs
22342233
if isinstance(rhs, FunctionCall) and isinstance(rhs.class_type, TupleType):
22352234
self._temporary_args = [ObjectAddress(a) for a in lhs]
2236-
return prefix_code+'{};\n'.format(self._print(rhs))
2235+
return f'{self._print(rhs)};\n'
22372236
# Inhomogenous tuples are unravelled and therefore do not exist in the c printer
22382237
if isinstance(rhs, (NumpyArray, PythonTuple)):
2239-
return prefix_code+self.copy_NumpyArray_Data(expr)
2238+
return self.copy_NumpyArray_Data(expr)
22402239
if isinstance(rhs, (NumpyFull)):
2241-
return prefix_code+self.arrayFill(expr)
2240+
return self.arrayFill(expr)
22422241
lhs = self._print(expr.lhs)
22432242
if isinstance(rhs, (PythonList, PythonSet, PythonDict)):
2244-
return prefix_code+self.init_stc_container(rhs, expr)
2243+
return self.init_stc_container(rhs, expr)
22452244
rhs = self._print(expr.rhs)
2246-
return prefix_code+f'{lhs} = {rhs};\n'
2245+
return f'{lhs} = {rhs};\n'
22472246

22482247
def _print_AliasAssign(self, expr):
22492248
lhs_var = expr.lhs
@@ -2632,6 +2631,18 @@ def _print_SetCopy(self, expr):
26322631
set_var = self._print(expr.set_variable)
26332632
return f'{var_type}_clone({set_var})'
26342633

2634+
def _print_SetUnion(self, expr):
    """
    Print the C expression for a call to `set.union`.

    An error is reported when the node has no direct `Assign` user node.
    An import of `Set_extensions` is registered for the set type, and the
    returned expression is a call to `<c_type>_union` which receives the
    address of the calling set, the number of arguments, and the address
    of each argument.

    Parameters
    ----------
    expr : SetUnion
        The node describing the call to `union`.

    Returns
    -------
    str
        The C code for the call.
    """
    if not expr.get_direct_user_nodes(lambda n: isinstance(n, Assign)):
        errors.report("The result of the union call must be saved into a variable",
                severity='error', symbol=expr)

    set_type = expr.set_variable.class_type
    c_type = self.get_c_type(set_type)
    self.add_import(Import('Set_extensions', AsName(VariableTypeAnnotation(set_type), c_type)))

    # Print the calling set first, then the arguments, in call order.
    base_code = self._print(ObjectAddress(expr.set_variable))
    n_args = str(len(expr.args))
    arg_codes = [self._print(ObjectAddress(a)) for a in expr.args]
    call_args = ', '.join([n_args] + arg_codes)
    return f'{c_type}_union({base_code}, {call_args})'
2645+
26352646
#=================== MACROS ==================
26362647

26372648
def _print_MacroShape(self, expr):

pyccel/codegen/printing/fcode.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from pyccel.ast.builtins import PythonTuple, DtypePrecisionToCastFunction
2525
from pyccel.ast.builtins import PythonBool, PythonList, PythonSet
2626

27+
from pyccel.ast.builtin_methods.set_methods import SetUnion
28+
2729
from pyccel.ast.core import FunctionDef, FunctionDefArgument, FunctionDefResult
2830
from pyccel.ast.core import SeparatorComment, Comment
2931
from pyccel.ast.core import ConstructorCall
@@ -262,7 +264,7 @@ def __init__(self, filename, prefix_module = None):
262264
self._constantImports = {}
263265
self._current_class = None
264266

265-
self._additional_code = None
267+
self._additional_code = ''
266268

267269
self.prefix_module = prefix_module
268270

@@ -425,8 +427,6 @@ def _handle_inline_func_call(self, expr, assign_lhs = None):
425427
# Everything before the return node needs handling before the line
426428
# which calls the inline function is executed
427429
code = self._print(body)
428-
if (not self._additional_code):
429-
self._additional_code = ''
430430
self._additional_code += code
431431

432432
# Collect statements from results to return object
@@ -888,7 +888,7 @@ def _print_Import(self, expr):
888888

889889
source = expr.source
890890
if isinstance(source, DottedName):
891-
source = source.name[-1]
891+
source = source.name[-1].python_value
892892
elif isinstance(source, LiteralString):
893893
source = source.python_value
894894
else:
@@ -1276,14 +1276,12 @@ def _print_Constant(self, expr):
12761276
def _print_DottedVariable(self, expr):
12771277
if isinstance(expr.lhs, FunctionCall):
12781278
base = expr.lhs.funcdef.results[0].var
1279-
if (not self._additional_code):
1280-
self._additional_code = ''
12811279
var_name = self.scope.get_new_name()
12821280
var = base.clone(var_name)
12831281

12841282
self.scope.insert_variable(var)
12851283

1286-
self._additional_code = self._additional_code + self._print(Assign(var,expr.lhs)) + '\n'
1284+
self._additional_code += self._print(Assign(var,expr.lhs)) + '\n'
12871285
return self._print(var) + '%' +self._print(expr.name)
12881286
else:
12891287
return self._print(expr.lhs) + '%' +self._print(expr.name)
@@ -1337,6 +1335,31 @@ def _print_SetCopy(self, expr):
13371335
type_name = self._print(expr.class_type)
13381336
return f'{type_name}({var_code})'
13391337

1338+
def _print_SetUnion(self, expr):
    """
    Print the Fortran code for a call to `set.union`.

    The generated code first assigns a copy of the calling set to the
    result variable and then emits one `merge` call per argument. When the
    node is directly used by an `Assign`, the result variable is that
    assignment's left-hand side and the statements are returned. Otherwise
    a temporary variable is requested from the scope, the statements are
    appended to `self._additional_code`, and the printed temporary is
    returned. An error is reported for any argument whose type differs
    from the type of the calling set.

    Parameters
    ----------
    expr : SetUnion
        The node describing the call to `union`.

    Returns
    -------
    str
        Either the full statements (assignment case) or the code of the
        temporary variable which holds the result.
    """
    parent_assigns = expr.get_direct_user_nodes(lambda n: isinstance(n, Assign))
    base_set = expr.set_variable
    if parent_assigns:
        target = self._print(parent_assigns[0].lhs)
    else:
        target = self._print(self.scope.get_temporary_variable(base_set))

    set_type = base_set.class_type
    base_code = self._print(expr.set_variable)
    type_code = self._print(set_type)
    self.add_import(self._build_gFTL_extension_module(set_type))

    merge_lines = []
    for arg in expr.args:
        arg_code = self._print(arg)
        same_type = arg.class_type == set_type
        if not same_type:
            errors.report(PYCCEL_RESTRICTION_TODO, severity = 'error', symbol = expr)
            continue
        merge_lines.append(f'call {target} % merge({arg_code})\n')

    code = ''.join([f'{target} = {type_code}({base_code})\n', *merge_lines])
    if not parent_assigns:
        # The statements must run before the line that uses the temporary.
        self._additional_code += code
        return target
    return code
1362+
13401363
#========================== Numpy Elements ===============================#
13411364

13421365
def _print_NumpySum(self, expr):
@@ -1675,12 +1698,10 @@ def _print_NumpyRand(self, expr):
16751698
errors.report(FORTRAN_ALLOCATABLE_IN_EXPRESSION,
16761699
symbol=expr, severity='fatal')
16771700

1678-
if (not self._additional_code):
1679-
self._additional_code = ''
16801701
var = self.scope.get_temporary_variable(expr.dtype, memory_handling = 'stack',
16811702
shape = expr.shape)
16821703

1683-
self._additional_code = self._additional_code + self._print(Assign(var,expr)) + '\n'
1704+
self._additional_code += self._print(Assign(var,expr)) + '\n'
16841705
return self._print(var)
16851706

16861707
def _print_NumpyRandint(self, expr):
@@ -1978,9 +1999,9 @@ def _print_CodeBlock(self, expr):
19781999
body_stmts = []
19792000
for b in body_exprs :
19802001
line = self._print(b)
1981-
if (self._additional_code):
2002+
if self._additional_code:
19822003
body_stmts.append(self._additional_code)
1983-
self._additional_code = None
2004+
self._additional_code = ''
19842005
body_stmts.append(line)
19852006
return ''.join(self._print(b) for b in body_stmts)
19862007

@@ -2000,7 +2021,7 @@ def _print_NumpyReal(self, expr):
20002021
def _print_Assign(self, expr):
20012022
rhs = expr.rhs
20022023

2003-
if isinstance(rhs, FunctionCall):
2024+
if isinstance(rhs, (FunctionCall, SetUnion)):
20042025
return self._print(rhs)
20052026

20062027
lhs_code = self._print(expr.lhs)
@@ -3397,17 +3418,13 @@ def _print_FunctionCall(self, expr):
33973418
args = args[1:]
33983419
if isinstance(class_variable, FunctionCall):
33993420
base = class_variable.funcdef.results[0].var
3400-
if (not self._additional_code):
3401-
self._additional_code = ''
34023421
var = self.scope.get_temporary_variable(base)
34033422

3404-
self._additional_code = self._additional_code + self._print(Assign(var, class_variable)) + '\n'
3423+
self._additional_code += self._print(Assign(var, class_variable)) + '\n'
34053424
f_name = f'{self._print(var)} % {f_name}'
34063425
else:
34073426
f_name = f'{self._print(class_variable)} % {f_name}'
34083427

3409-
if (not self._additional_code):
3410-
self._additional_code = ''
34113428
if parent_assign:
34123429
lhs = parent_assign[0].lhs
34133430
if len(func_results) == 1:

pyccel/codegen/printing/pycode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,8 @@ def _print_EmptyNode(self, expr):
524524
return ''
525525

526526
def _print_DottedName(self, expr):
527-
return '.'.join(self._print(n) for n in expr.name)
527+
# A DottedName can only contain LiteralStrings or PyccelSymbols at the printing stage
528+
return '.'.join(str(n) for n in expr.name)
528529

529530
def _print_FunctionCall(self, expr):
530531
func = expr.funcdef

pyccel/parser/semantic.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
PythonTuple, Lambda, PythonMap)
3030

3131
from pyccel.ast.builtin_methods.list_methods import ListMethod, ListAppend
32-
from pyccel.ast.builtin_methods.set_methods import SetMethod, SetAdd
32+
from pyccel.ast.builtin_methods.set_methods import SetAdd, SetUnion
3333

3434
from pyccel.ast.core import Comment, CommentBlock, Pass
3535
from pyccel.ast.core import If, IfSection
@@ -1095,6 +1095,7 @@ def _handle_function(self, expr, func, args, *, is_method = False, use_build_fun
10951095
new_expr = FunctionCall(func, args)
10961096
new_expr.set_current_ast(expr.python_ast)
10971097
pyccel_stage.set_stage('semantic')
1098+
new_expr.set_current_user_node(expr.current_user_node)
10981099
expr = new_expr
10991100
return getattr(self, annotation_method)(expr)
11001101

@@ -1566,7 +1567,7 @@ def _assign_lhs_variable(self, lhs, d_var, rhs, new_expressions, is_augassign,ar
15661567
know_lhs_shape = (lhs.rank == 0) or all(sh is not None for sh in lhs.alloc_shape) \
15671568
or isinstance(lhs.class_type, StringType)
15681569

1569-
if not know_lhs_shape:
1570+
if isinstance(class_type, (NumpyNDArrayType, HomogeneousTupleType)) and not know_lhs_shape:
15701571
msg = f"Cannot infer shape of right-hand side for expression {lhs} = {rhs}"
15711572
errors.report(PYCCEL_RESTRICTION_TODO+'\n'+msg,
15721573
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
@@ -3158,13 +3159,7 @@ def _visit_Assign(self, expr):
31583159
symbol=expr, severity='error')
31593160

31603161
# Checking for the result of _visit_ListExtend
3161-
if isinstance(rhs, For) or (isinstance(rhs, CodeBlock) and
3162-
isinstance(rhs.body[0], (ListMethod, SetMethod))):
3163-
return rhs
3164-
if isinstance(rhs, ConstructorCall):
3165-
return rhs
3166-
3167-
elif isinstance(rhs, CodeBlock) and len(rhs.body)>1 and isinstance(rhs.body[1], FunctionalFor):
3162+
if isinstance(rhs, (For, CodeBlock, ConstructorCall)):
31683163
return rhs
31693164

31703165
elif isinstance(rhs, FunctionCall):
@@ -5032,7 +5027,7 @@ def _build_SetUpdate(self, expr):
50325027
elements of the iterable. If not, it attempts to construct a syntactic `For`
50335028
loop to iterate over the iterable object and added its elements to the set
50345029
object. Finally, it passes to a `_visit()` call for semantic parsing.
5035-
5030+
50365031
Parameters
50375032
----------
50385033
expr : DottedName
@@ -5067,3 +5062,47 @@ def _build_SetUpdate(self, expr):
50675062
pyccel_stage.set_stage('semantic')
50685063
return self._visit(for_obj)
50695064

5065+
def _build_SetUnion(self, expr):
    """
    Method to navigate the syntactic DottedName node of a `set.union()` call.

    The purpose of this `_build` method is to construct new nodes from a syntactic
    DottedName node. It creates a SetUnion node if the type of the arguments matches
    the type of the original set. Otherwise it uses `set.copy` and `set.update` to
    handle iterators. The second strategy writes into the left-hand side of the
    enclosing assignment, so an error is reported if the call is not used in an
    assignment.

    Parameters
    ----------
    expr : DottedName
        The syntactic DottedName node that represent the call to `.union()`.

    Returns
    -------
    SetUnion | CodeBlock
        The nodes describing the union operator.
    """
    syntactic_set_obj = expr.name[0]
    syntactic_args = expr.name[1].args
    set_obj = self._visit(syntactic_set_obj)
    args = [self._visit(a.value) for a in syntactic_args]
    class_type = set_obj.class_type

    # Simple case: every argument is a set of exactly the same type.
    if all(a.class_type == class_type for a in args):
        return SetUnion(set_obj, *args)

    element_type = class_type.element_type
    if any(a.class_type.element_type != element_type for a in args):
        errors.report(("Containers containing objects of a different type cannot be used as "
                       f"arguments to {class_type}.union"),
                severity='fatal', symbol=expr)

    # The copy/update expansion needs a variable to build the result in.
    # Report a clear error instead of failing with an IndexError when the
    # call is not the right-hand side of an assignment.
    # NOTE(review): a 'fatal' report is assumed to abort the visit - confirm.
    assign_users = expr.get_user_nodes(Assign)
    if not assign_users:
        errors.report("The result of the union call must be saved into a variable",
                severity='fatal', symbol=expr)
    lhs = assign_users[0].lhs

    # Build `lhs = set_obj.copy()` followed by one `lhs.update(arg)` per
    # argument as syntactic nodes, then visit them semantically.
    pyccel_stage.set_stage('syntactic')
    body = [Assign(lhs, DottedName(syntactic_set_obj, FunctionCall('copy', ())),
                   python_ast = expr.python_ast)]
    update_calls = [DottedName(lhs, FunctionCall('update', (s_a,))) for s_a in syntactic_args]
    for c in update_calls:
        c.set_current_ast(expr.python_ast)
    body += [Assign(PyccelSymbol('_', is_temp=True), c, python_ast = expr.python_ast)
             for c in update_calls]
    pyccel_stage.set_stage('semantic')
    return CodeBlock([self._visit(b) for b in body])

0 commit comments

Comments
 (0)