Skip to content

Commit 0f25459

Browse files
authored
Add Fortran support for set initialisation (pyccel#2004)
Add Fortran support for set initialisation in a similar way to lists. Fixes pyccel#1658
1 parent f3b4bde commit 0f25459

File tree

9 files changed

+257
-87
lines changed

9 files changed

+257
-87
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ All notable changes to this project will be documented in this file.
2828
- #1659 : Add the appropriate C language equivalent for declaring a Python `set` container using the STC library.
2929
- #1944 : Add the appropriate C language equivalent for declaring a Python `dict` container using the STC library.
3030
- #1657 : Add the appropriate Fortran language equivalent for declaring a Python `list` container using the gFTL library.
31+
- #1658 : Add the appropriate Fortran language equivalent for declaring a Python `set` container using the gFTL library.
32+
- #1944 : Add the appropriate Fortran language equivalent for declaring a Python `dict` container using the gFTL library.
3133
- #1874 : Add C and Fortran support for the `len()` function for the `list` container.
3234
- #1875 : Add C and Fortran support for the `len()` function for the `set` container.
3335
- #1908 : Add C and Fortran support for the `len()` function for the `dict` container.

docs/builtin-functions.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,27 @@ Python contains a limited number of builtin functions defined [here](https://doc
9292
| `reverse` | No |
9393
| `sort` | Python-only |
9494

95+
## Set methods
96+
97+
| Method | Supported |
98+
|----------|-----------|
99+
| `add` | Python-only |
100+
| `clear` | Python-only |
101+
| `copy` | Python-only |
102+
| `difference` | No |
103+
| `difference_update` | No |
104+
| `discard` | Python-only |
105+
| `intersection` | No |
106+
| `intersection_update` | No |
107+
| `isdisjoint` | No |
108+
| `issubset` | No |
109+
| `issuperset` | No |
110+
| `pop` | Python-only |
111+
| `remove` | Python-only |
112+
| `symmetric_difference` | No |
113+
| `symmetric_difference_update` | No |
114+
| `union` | No |
115+
| `update` | Python-only |
95116

96117
## Dictionary methods
97118

pyccel/ast/builtins.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -730,8 +730,10 @@ class PythonListFunction(PyccelFunction):
730730
__slots__ = ('_class_type', '_shape')
731731
_attribute_nodes = ()
732732

733-
def __new__(cls, arg):
734-
if isinstance(arg, PythonList):
733+
def __new__(cls, arg = None):
734+
if arg is None:
735+
return PythonList()
736+
elif isinstance(arg, PythonList):
735737
return arg
736738
elif isinstance(arg.shape[0], LiteralInteger):
737739
return PythonList(*[arg[i] for i in range(arg.shape[0])])
@@ -774,6 +776,11 @@ def __init__(self, *args):
774776
super().__init__()
775777
if pyccel_stage == 'syntactic':
776778
return
779+
elif len(args) == 0:
780+
self._shape = (LiteralInteger(0),)
781+
self._class_type = HomogeneousSetType(GenericType())
782+
return
783+
777784
arg0 = args[0]
778785
is_homogeneous = arg0.class_type is not GenericType() and \
779786
all(a.class_type is not GenericType() and \
@@ -829,8 +836,10 @@ class PythonSetFunction(PyccelFunction):
829836

830837
__slots__ = ('_shape', '_class_type')
831838
name = 'set'
832-
def __new__(cls, arg):
833-
if isinstance(arg.class_type, HomogeneousSetType):
839+
def __new__(cls, arg = None):
840+
if arg is None:
841+
return PythonSet()
842+
elif isinstance(arg.class_type, HomogeneousSetType):
834843
return arg
835844
elif isinstance(arg, (PythonList, PythonSet, PythonTuple)):
836845
return PythonSet(*arg)

pyccel/codegen/pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,19 @@ def handle_error(stage):
335335
verbose = verbose)
336336

337337
mod_obj.add_dependencies(stdlib)
338-
338+
modules.append(stdlib)
339339

340340
# Iterate over the external_libs list and determine if the printer
341341
# requires an external lib to be included.
342342
for key, import_node in codegen.get_printer_imports().items():
343343
try:
344-
deps = generate_extension_modules(key, import_node, pyccel_dirpath, language)
344+
deps = generate_extension_modules(key, import_node, pyccel_dirpath,
345+
includes = includes,
346+
libs = compile_libs,
347+
libdirs = libdirs,
348+
dependencies = modules,
349+
accelerators = accelerators,
350+
language = language)
345351
except NotImplementedError as error:
346352
errors.report(f'{error}\n'+PYCCEL_RESTRICTION_TODO,
347353
severity='error',

pyccel/codegen/printing/fcode.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pyccel.ast.builtins import PythonTuple, DtypePrecisionToCastFunction
2525
from pyccel.ast.builtins import PythonBool, PythonList, PythonSet
2626

27-
from pyccel.ast.core import FunctionDef
27+
from pyccel.ast.core import FunctionDef, FunctionDefArgument, FunctionDefResult
2828
from pyccel.ast.core import SeparatorComment, Comment
2929
from pyccel.ast.core import ConstructorCall
3030
from pyccel.ast.core import FunctionCallArgument
@@ -37,7 +37,7 @@
3737
from pyccel.ast.datatypes import PrimitiveBooleanType, PrimitiveIntegerType, PrimitiveFloatingPointType, PrimitiveComplexType
3838
from pyccel.ast.datatypes import SymbolicType, StringType, FixedSizeNumericType, HomogeneousContainerType
3939
from pyccel.ast.datatypes import HomogeneousTupleType, HomogeneousListType, HomogeneousSetType, DictType
40-
from pyccel.ast.datatypes import PythonNativeInt
40+
from pyccel.ast.datatypes import PythonNativeInt, PythonNativeBool
4141
from pyccel.ast.datatypes import CustomDataType, InhomogeneousTupleType, TupleType
4242
from pyccel.ast.datatypes import pyccel_type_to_original_type, PyccelType
4343

@@ -550,22 +550,46 @@ def _build_gFTL_module(self, expr_type):
550550
include = Import(LiteralString('vector/template.inc'), Module('_', (), ()))
551551
element_type = expr_type.element_type
552552
if isinstance(element_type, FixedSizeNumericType):
553-
macros = [MacroDefinition('T', element_type.primitive_type),
553+
imports_and_macros = [MacroDefinition('T', element_type.primitive_type),
554554
MacroDefinition('T_KINDLEN(context)', KindSpecification(element_type))]
555555
else:
556-
macros = [MacroDefinition('T', element_type)]
556+
imports_and_macros = [MacroDefinition('T', element_type)]
557557
if isinstance(element_type, (NumpyNDArrayType, HomogeneousTupleType)):
558-
macros.append(MacroDefinition('T_rank', element_type.rank))
558+
imports_and_macros.append(MacroDefinition('T_rank', element_type.rank))
559559
elif not isinstance(element_type, FixedSizeNumericType):
560560
raise NotImplementedError("Support for lists of types defined in other modules is not yet implemented")
561-
macros.append(MacroDefinition('Vector', expr_type))
562-
macros.append(MacroDefinition('VectorIterator', IteratorType(expr_type)))
561+
imports_and_macros.append(MacroDefinition('Vector', expr_type))
562+
imports_and_macros.append(MacroDefinition('VectorIterator', IteratorType(expr_type)))
563+
elif isinstance(expr_type, HomogeneousSetType):
564+
include = Import(LiteralString('set/template.inc'), Module('_', (), ()))
565+
element_type = expr_type.element_type
566+
imports_and_macros = []
567+
if isinstance(element_type, FixedSizeNumericType):
568+
tmpVar_x = Variable(element_type, 'x')
569+
tmpVar_y = Variable(element_type, 'y')
570+
if isinstance(element_type.primitive_type, PrimitiveComplexType):
571+
complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),()))
572+
self.add_import(complex_tool_import)
573+
imports_and_macros.append(complex_tool_import)
574+
compare_func = FunctionDef('complex_comparison',
575+
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
576+
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
577+
lt_def = FunctionCall(compare_func, [tmpVar_x, tmpVar_y])
578+
else:
579+
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))
580+
imports_and_macros.extend([MacroDefinition('T', element_type.primitive_type),
581+
MacroDefinition('T_KINDLEN(context)', KindSpecification(element_type)),
582+
MacroDefinition('T_LT(x,y)', lt_def)])
583+
else:
584+
raise NotImplementedError("Support for sets of types which define their own < operator is not yet implemented")
585+
imports_and_macros.append(MacroDefinition('Set', expr_type))
586+
imports_and_macros.append(MacroDefinition('SetIterator', IteratorType(expr_type)))
563587
else:
564588
raise NotImplementedError(f"Unkown gFTL import for type {expr_type}")
565589

566590
typename = self._print(expr_type)
567591
mod_name = f'{typename}_mod'
568-
module = Module(mod_name, (), (), scope = Scope(), imports = [*macros, include],
592+
module = Module(mod_name, (), (), scope = Scope(), imports = [*imports_and_macros, include],
569593
is_external = True)
570594

571595
self._generated_gFTL_extensions[expr_type] = module
@@ -1091,6 +1115,21 @@ def _print_PythonList(self, expr):
10911115
vec_type = self._print(expr.class_type)
10921116
return f'{vec_type}({list_arg})'
10931117

1118+
def _print_PythonSet(self, expr):
1119+
if len(expr.args) == 0:
1120+
list_arg = ''
1121+
assign = expr.get_direct_user_nodes(lambda a : isinstance(a, Assign))
1122+
if assign:
1123+
set_type = self._print(assign[0].lhs.class_type)
1124+
else:
1125+
raise errors.report("Can't use an empty set without assigning it to a variable as the type cannot be deduced",
1126+
severity='fatal', symbol=expr)
1127+
1128+
else:
1129+
list_arg = self._print_PythonTuple(expr)
1130+
set_type = self._print(expr.class_type)
1131+
return f'{set_type}({list_arg})'
1132+
10941133
def _print_InhomogeneousTupleVariable(self, expr):
10951134
fs = ', '.join(self._print(f) for f in expr)
10961135
return '[{0}]'.format(fs)
@@ -1967,7 +2006,7 @@ def _print_Deallocate(self, expr):
19672006
Pyccel_del_args = [FunctionCallArgument(var)]
19682007
return self._print(FunctionCall(Pyccel__del, Pyccel_del_args))
19692008

1970-
if var.is_alias or isinstance(class_type, HomogeneousListType):
2009+
if var.is_alias or isinstance(class_type, (HomogeneousListType, HomogeneousSetType)):
19712010
return ''
19722011
elif isinstance(class_type, (NumpyNDArrayType, HomogeneousTupleType, StringType)):
19732012
var_code = self._print(var)
@@ -2011,6 +2050,9 @@ def _print_PythonNativeBool(self, expr):
20112050
def _print_HomogeneousListType(self, expr):
20122051
return 'Vector_'+self._print(expr.element_type)
20132052

2053+
def _print_HomogeneousSetType(self, expr):
2054+
return 'Set_'+self._print(expr.element_type)
2055+
20142056
def _print_IteratorType(self, expr):
20152057
iterable_type = self._print(expr.iterable_type)
20162058
return f"{iterable_type}_Iterator"

pyccel/codegen/utilities.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"ndarrays" : ("ndarrays", CompileObj("ndarrays.c",folder="ndarrays")),
4141
"pyc_math_f90" : ("math", CompileObj("pyc_math_f90.f90",folder="math")),
4242
"pyc_math_c" : ("math", CompileObj("pyc_math_c.c",folder="math")),
43+
"pyc_tools_f90" : ("tools", CompileObj("pyc_tools_f90.f90",folder="tools")),
4344
"cwrapper" : ("cwrapper", CompileObj("cwrapper.c",folder="cwrapper", accelerators=('python',))),
4445
"numpy_f90" : ("numpy", CompileObj("numpy_f90.f90",folder="numpy")),
4546
"numpy_c" : ("numpy", CompileObj("numpy_c.c",folder="numpy")),
@@ -180,7 +181,9 @@ def copy_internal_library(lib_folder, pyccel_dirpath, extra_files = None):
180181
return lib_dest_path
181182

182183
#==============================================================================
183-
def generate_extension_modules(import_key, import_node, pyccel_dirpath, language):
184+
def generate_extension_modules(import_key, import_node, pyccel_dirpath,
185+
includes, libs, libdirs, dependencies,
186+
accelerators, language):
184187
"""
185188
Generate any new modules that describe extensions.
186189
@@ -196,6 +199,16 @@ def generate_extension_modules(import_key, import_node, pyccel_dirpath, language
196199
be printed).
197200
pyccel_dirpath : str
198201
The folder where files are being saved.
202+
includes : iterable of strs
203+
Include directories paths.
204+
libs : iterable of strs
205+
Required libraries.
206+
libdirs : iterable of strs
207+
Paths to directories containing the required libraries.
208+
dependencies : iterable of CompileObjs
209+
Objects which must also be compiled in order to compile this module/program.
210+
accelerators : iterable of str
211+
Tool used to accelerate the code (e.g. openmp openacc).
199212
language : str
200213
The language in which code is being printed.
201214
@@ -221,7 +234,9 @@ def generate_extension_modules(import_key, import_node, pyccel_dirpath, language
221234
f.write(code)
222235

223236
new_dependencies.append(CompileObj(os.path.basename(filename), folder=folder,
224-
includes=(os.path.join(pyccel_dirpath, 'gFTL'),)))
237+
includes=(os.path.join(pyccel_dirpath, 'gFTL'), *includes),
238+
libs=libs, libdirs=libdirs, dependencies=dependencies,
239+
accelerators=accelerators))
225240

226241
if lib_name in external_libs:
227242
copy_internal_library(lib_name, pyccel_dirpath)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
! -------------------------------------------------------------------------------------- !
2+
! This file is part of Pyccel which is released under MIT License. See the LICENSE file !
3+
! or go to https://github.com/pyccel/pyccel/blob/devel/LICENSE for full license details. !
4+
! -------------------------------------------------------------------------------------- !
5+
6+
module pyc_tools_f90
7+
8+
use, intrinsic :: ISO_C_Binding, only : b1 => C_BOOL, &
9+
f32 => C_FLOAT, &
10+
f64 => C_DOUBLE, &
11+
c64 => C_DOUBLE_COMPLEX, &
12+
c32 => C_FLOAT_COMPLEX
13+
14+
implicit none
15+
16+
interface complex_comparison
17+
module procedure complex_comparison_4
18+
module procedure complex_comparison_8
19+
end interface complex_comparison
20+
21+
contains
22+
23+
function complex_comparison_4(x, y) result(c)
24+
complex(c32) :: x
25+
complex(c32) :: y
26+
logical(b1) :: c
27+
real(f32) :: real_x
28+
real(f32) :: real_y
29+
30+
real_x = real(x)
31+
real_y = real(y)
32+
33+
c = merge(real_x < real_y, aimag(x) < aimag(y), real_x /= real_y)
34+
35+
end function complex_comparison_4
36+
37+
function complex_comparison_8(x, y) result(c)
38+
complex(c64) :: x
39+
complex(c64) :: y
40+
logical(b1) :: c
41+
real(f64) :: real_x
42+
real(f64) :: real_y
43+
44+
real_x = real(x)
45+
real_y = real(y)
46+
47+
c = merge(real_x < real_y, aimag(x) < aimag(y), real_x /= real_y)
48+
49+
end function complex_comparison_8
50+
51+
end module pyc_tools_f90

0 commit comments

Comments
 (0)