Skip to content

Commit bac6b28

Browse files
authored
Support sets as arguments to functions (pyccel#2031)
Improve the C-Python wrapper to allow passing sets as arguments. Fixes pyccel#1663 **Commit Summary** - Add documentation for set annotation - Add description for C-Python API set manipulation functions - Add support for constant set arguments to C functions - Remove include guards from `Set_extensions.h` to allow extensions to multiple set types - Add a test for a function with a constant set argument
1 parent 486e6de commit bac6b28

File tree

8 files changed

+285
-58
lines changed

8 files changed

+285
-58
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ All notable changes to this project will be documented in this file.
3737
- #1689 : Add C and Fortran support for list method `append()`.
3838
- #1876 : Add C support for indexing lists.
3939
- #1690 : Add C support for list method `pop()`.
40+
- #1663 : Add C support for sets as constant arguments.
4041
- #1664 : Add C support for returning sets from functions.
4142
- #2023 : Add support for iterating over a `set`.
4243
- #1877 : Add C and Fortran Support for set method `pop()`.

docs/type_annotations.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,15 @@ a : set[int] = {1, 2}
8282
b : set[bool] = {False, True}
8383
c : set[float] = {}
8484
```
85-
So far sets can be declared as local variables or as results of functions.
85+
Sets can be declared as local variables, arguments or results of functions translated to C. An argument can be marked as constant using a string annotation or (in a module) using the `Final` qualifier:
86+
```python
87+
def f(a : 'const set[int]'):
88+
pass
89+
90+
from typing import Final
91+
def g(b : Final[set[bool]]):
92+
pass
93+
```
8694

8795
## Dictionaries
8896

pyccel/ast/cwrapper.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def C_to_Python(c_object):
11431143
results = [FunctionDefResult(Variable(CNativeInt(), 'i'))])
11441144

11451145
#-------------------------------------------------------------------
1146-
# Set functions
1146+
# Set functions
11471147
#-------------------------------------------------------------------
11481148

11491149
# https://docs.python.org/3/c-api/set.html#c.PySet_New
@@ -1159,6 +1159,36 @@ def C_to_Python(c_object):
11591159
results = [FunctionDefResult(Variable(PythonNativeInt(), 'i'))],
11601160
body = [])
11611161

1162+
# https://docs.python.org/3/c-api/set.html#c.PySet_Check
1163+
PySet_Check = FunctionDef(name = 'PySet_Check',
1164+
arguments = [FunctionDefArgument(Variable(PyccelPyObject(), 'set', memory_handling='alias'))],
1165+
results = [FunctionDefResult(Variable(CNativeInt(), 'i'))],
1166+
body = [])
1167+
1168+
# https://docs.python.org/3/c-api/set.html#c.PySet_Size
1169+
PySet_Size = FunctionDef(name = 'PySet_Size',
1170+
arguments = [FunctionDefArgument(Variable(PyccelPyObject(), 'set', memory_handling='alias'))],
1171+
results = [FunctionDefResult(Variable(PythonNativeInt(), 'i'))],
1172+
body = [])
1173+
1174+
# https://docs.python.org/3/c-api/object.html#c.PyObject_GetIter
1175+
PySet_GetIter = FunctionDef(name = 'PyObject_GetIter',
1176+
body = [],
1177+
arguments = [FunctionDefArgument(Variable(PyccelPyObject(), name='iter', memory_handling='alias'))],
1178+
results = [FunctionDefResult(Variable(PyccelPyObject(), name='o', memory_handling='alias'))])
1179+
1180+
# https://docs.python.org/3/c-api/set.html#c.PySet_Clear
1181+
PySet_Clear = FunctionDef(name = 'PySet_Clear',
1182+
body = [],
1183+
arguments = [FunctionDefArgument(Variable(PyccelPyObject(), name='set', memory_handling='alias'))],
1184+
results = [FunctionDefResult(Variable(PythonNativeInt(), 'i'))])
1185+
1186+
# https://docs.python.org/3/c-api/iter.html#c.PyIter_Check
1187+
PyIter_Next = FunctionDef(name = 'PyIter_Next',
1188+
body = [],
1189+
arguments = [FunctionDefArgument(Variable(PyccelPyObject(), name='iter', memory_handling='alias'))],
1190+
results = [FunctionDefResult(Variable(PyccelPyObject(), name='o', memory_handling='alias'))])
1191+
11621192

11631193
# Functions definitions are defined in pyccel/stdlib/cwrapper/cwrapper.c
11641194
check_type_registry = {

pyccel/codegen/printing/ccode.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,34 @@ def __init__(self, filename, prefix_module = None):
325325
self._current_module = None
326326
self._in_header = False
327327

328+
def sort_imports(self, imports):
329+
"""
330+
Sort imports to avoid any errors due to bad ordering.
331+
332+
Sort imports. This is important so that types exist before they are used to create
333+
container types. E.g. it is important that complex or inttypes be imported before
334+
vec_int or vec_double_complex is declared.
335+
336+
Parameters
337+
----------
338+
imports : list[Import]
339+
A list of the imports.
340+
341+
Returns
342+
-------
343+
list[Import]
344+
A sorted list of the imports.
345+
"""
346+
import_src = [str(i.source) for i in imports]
347+
stc_imports = [i for i in import_src if i.startswith('stc/')]
348+
dependent_imports = [i for i in import_src if i in import_header_guard_prefix]
349+
non_stc_imports = [i for i in import_src if i not in chain(stc_imports, dependent_imports)]
350+
stc_imports.sort()
351+
dependent_imports.sort()
352+
non_stc_imports.sort()
353+
sorted_imports = [imports[import_src.index(name)] for name in chain(non_stc_imports, stc_imports, dependent_imports)]
354+
return sorted_imports
355+
328356
def _format_code(self, lines):
329357
return self.indent_code(lines)
330358

@@ -364,13 +392,16 @@ def is_c_pointer(self, a):
364392
return True
365393
if isinstance(a, FunctionCall):
366394
a = a.funcdef.results[0].var
367-
if isinstance(getattr(a, 'dtype', None), CustomDataType) and a.is_argument:
368-
return True
369-
370395
if not isinstance(a, Variable):
371396
return False
372-
return (a.is_alias and not isinstance(a.class_type, (HomogeneousTupleType, NumpyNDArrayType))) \
373-
or a.is_optional or \
397+
if isinstance(a.class_type, (HomogeneousTupleType, NumpyNDArrayType)):
398+
return a.is_optional or any(a is bi for b in self._additional_args for bi in b)
399+
400+
if isinstance(a.class_type, (CustomDataType, HomogeneousContainerType, DictType)) \
401+
and a.is_argument and not a.is_const:
402+
return True
403+
404+
return a.is_alias or a.is_optional or \
374405
any(a is bi for b in self._additional_args for bi in b)
375406

376407
#========================== Numpy Elements ===============================#
@@ -745,7 +776,7 @@ def _print_PythonMinMax(self, expr):
745776
op = '<' if isinstance(expr, PythonMin) else '>'
746777
return f"({arg1} {op} {arg2} ? {arg1} : {arg2})"
747778
elif len(arg) > 2 and isinstance(arg.dtype.primitive_type, (PrimitiveFloatingPointType, PrimitiveIntegerType)):
748-
key = self.get_declare_type(arg[0])
779+
key = self.get_c_type(arg[0].class_type)
749780
self.add_import(Import('stc/common', AsName(VariableTypeAnnotation(arg.dtype), key)))
750781
self.add_import(Import('Common_extensions',
751782
AsName(VariableTypeAnnotation(arg.dtype), key),
@@ -851,6 +882,7 @@ def _print_ModuleHeader(self, expr):
851882

852883
# Print imports last to be sure that all additional_imports have been collected
853884
imports = [*expr.module.imports, *self._additional_imports.values()]
885+
imports = self.sort_imports(imports)
854886
imports = ''.join(self._print(i) for i in imports)
855887

856888
self._in_header = False
@@ -2658,6 +2690,7 @@ def _print_Program(self, expr):
26582690
decs = ''.join(self._print(Declare(v)) for v in variables)
26592691

26602692
imports = [*expr.imports, *self._additional_imports.values()]
2693+
imports = self.sort_imports(imports)
26612694
imports = ''.join(self._print(i) for i in imports)
26622695

26632696
self.exit_scope()
@@ -2724,18 +2757,18 @@ def _print_SetPop(self, expr):
27242757
return f'{var_type}_pop({set_var})'
27252758

27262759
def _print_SetClear(self, expr):
2727-
var_type = self.get_declare_type(expr.set_variable)
2760+
var_type = self.get_c_type(expr.set_variable.class_type)
27282761
set_var = self._print(ObjectAddress(expr.set_variable))
27292762
return f'{var_type}_clear({set_var});\n'
27302763

27312764
def _print_SetAdd(self, expr):
2732-
var_type = self.get_declare_type(expr.set_variable)
2765+
var_type = self.get_c_type(expr.set_variable.class_type)
27332766
set_var = self._print(ObjectAddress(expr.set_variable))
27342767
arg = self._print(expr.args[0])
27352768
return f'{var_type}_push({set_var}, {arg});\n'
27362769

27372770
def _print_SetCopy(self, expr):
2738-
var_type = self.get_declare_type(expr.set_variable)
2771+
var_type = self.get_c_type(expr.set_variable.class_type)
27392772
set_var = self._print(expr.set_variable)
27402773
return f'{var_type}_clone({set_var})'
27412774

0 commit comments

Comments
 (0)