|
24 | 24 | from pyccel.ast.builtins import PythonTuple, DtypePrecisionToCastFunction
|
25 | 25 | from pyccel.ast.builtins import PythonBool, PythonList, PythonSet
|
26 | 26 |
|
27 |
| -from pyccel.ast.core import FunctionDef |
| 27 | +from pyccel.ast.core import FunctionDef, FunctionDefArgument, FunctionDefResult |
28 | 28 | from pyccel.ast.core import SeparatorComment, Comment
|
29 | 29 | from pyccel.ast.core import ConstructorCall
|
30 | 30 | from pyccel.ast.core import FunctionCallArgument
|
|
37 | 37 | from pyccel.ast.datatypes import PrimitiveBooleanType, PrimitiveIntegerType, PrimitiveFloatingPointType, PrimitiveComplexType
|
38 | 38 | from pyccel.ast.datatypes import SymbolicType, StringType, FixedSizeNumericType, HomogeneousContainerType
|
39 | 39 | from pyccel.ast.datatypes import HomogeneousTupleType, HomogeneousListType, HomogeneousSetType, DictType
|
40 |
| -from pyccel.ast.datatypes import PythonNativeInt |
| 40 | +from pyccel.ast.datatypes import PythonNativeInt, PythonNativeBool |
41 | 41 | from pyccel.ast.datatypes import CustomDataType, InhomogeneousTupleType, TupleType
|
42 | 42 | from pyccel.ast.datatypes import pyccel_type_to_original_type, PyccelType
|
43 | 43 |
|
@@ -550,22 +550,46 @@ def _build_gFTL_module(self, expr_type):
|
550 | 550 | include = Import(LiteralString('vector/template.inc'), Module('_', (), ()))
|
551 | 551 | element_type = expr_type.element_type
|
552 | 552 | if isinstance(element_type, FixedSizeNumericType):
|
553 |
| - macros = [MacroDefinition('T', element_type.primitive_type), |
| 553 | + imports_and_macros = [MacroDefinition('T', element_type.primitive_type), |
554 | 554 | MacroDefinition('T_KINDLEN(context)', KindSpecification(element_type))]
|
555 | 555 | else:
|
556 |
| - macros = [MacroDefinition('T', element_type)] |
| 556 | + imports_and_macros = [MacroDefinition('T', element_type)] |
557 | 557 | if isinstance(element_type, (NumpyNDArrayType, HomogeneousTupleType)):
|
558 |
| - macros.append(MacroDefinition('T_rank', element_type.rank)) |
| 558 | + imports_and_macros.append(MacroDefinition('T_rank', element_type.rank)) |
559 | 559 | elif not isinstance(element_type, FixedSizeNumericType):
|
560 | 560 | raise NotImplementedError("Support for lists of types defined in other modules is not yet implemented")
|
561 |
| - macros.append(MacroDefinition('Vector', expr_type)) |
562 |
| - macros.append(MacroDefinition('VectorIterator', IteratorType(expr_type))) |
| 561 | + imports_and_macros.append(MacroDefinition('Vector', expr_type)) |
| 562 | + imports_and_macros.append(MacroDefinition('VectorIterator', IteratorType(expr_type))) |
| 563 | + elif isinstance(expr_type, HomogeneousSetType): |
| 564 | + include = Import(LiteralString('set/template.inc'), Module('_', (), ())) |
| 565 | + element_type = expr_type.element_type |
| 566 | + imports_and_macros = [] |
| 567 | + if isinstance(element_type, FixedSizeNumericType): |
| 568 | + tmpVar_x = Variable(element_type, 'x') |
| 569 | + tmpVar_y = Variable(element_type, 'y') |
| 570 | + if isinstance(element_type.primitive_type, PrimitiveComplexType): |
| 571 | + complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),())) |
| 572 | + self.add_import(complex_tool_import) |
| 573 | + imports_and_macros.append(complex_tool_import) |
| 574 | + compare_func = FunctionDef('complex_comparison', |
| 575 | + [FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)], |
| 576 | + [FunctionDefResult(Variable(PythonNativeBool(), 'c'))], []) |
| 577 | + lt_def = FunctionCall(compare_func, [tmpVar_x, tmpVar_y]) |
| 578 | + else: |
| 579 | + lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y)) |
| 580 | + imports_and_macros.extend([MacroDefinition('T', element_type.primitive_type), |
| 581 | + MacroDefinition('T_KINDLEN(context)', KindSpecification(element_type)), |
| 582 | + MacroDefinition('T_LT(x,y)', lt_def)]) |
| 583 | + else: |
| 584 | + raise NotImplementedError("Support for sets of types which define their own < operator is not yet implemented") |
| 585 | + imports_and_macros.append(MacroDefinition('Set', expr_type)) |
| 586 | + imports_and_macros.append(MacroDefinition('SetIterator', IteratorType(expr_type))) |
563 | 587 | else:
|
564 | 588 | raise NotImplementedError(f"Unkown gFTL import for type {expr_type}")
|
565 | 589 |
|
566 | 590 | typename = self._print(expr_type)
|
567 | 591 | mod_name = f'{typename}_mod'
|
568 |
| - module = Module(mod_name, (), (), scope = Scope(), imports = [*macros, include], |
| 592 | + module = Module(mod_name, (), (), scope = Scope(), imports = [*imports_and_macros, include], |
569 | 593 | is_external = True)
|
570 | 594 |
|
571 | 595 | self._generated_gFTL_extensions[expr_type] = module
|
@@ -1091,6 +1115,21 @@ def _print_PythonList(self, expr):
|
1091 | 1115 | vec_type = self._print(expr.class_type)
|
1092 | 1116 | return f'{vec_type}({list_arg})'
|
1093 | 1117 |
|
| 1118 | + def _print_PythonSet(self, expr): |
| 1119 | + if len(expr.args) == 0: |
| 1120 | + list_arg = '' |
| 1121 | + assign = expr.get_direct_user_nodes(lambda a : isinstance(a, Assign)) |
| 1122 | + if assign: |
| 1123 | + set_type = self._print(assign[0].lhs.class_type) |
| 1124 | + else: |
| 1125 | + raise errors.report("Can't use an empty set without assigning it to a variable as the type cannot be deduced", |
| 1126 | + severity='fatal', symbol=expr) |
| 1127 | + |
| 1128 | + else: |
| 1129 | + list_arg = self._print_PythonTuple(expr) |
| 1130 | + set_type = self._print(expr.class_type) |
| 1131 | + return f'{set_type}({list_arg})' |
| 1132 | + |
1094 | 1133 | def _print_InhomogeneousTupleVariable(self, expr):
|
1095 | 1134 | fs = ', '.join(self._print(f) for f in expr)
|
1096 | 1135 | return '[{0}]'.format(fs)
|
@@ -1967,7 +2006,7 @@ def _print_Deallocate(self, expr):
|
1967 | 2006 | Pyccel_del_args = [FunctionCallArgument(var)]
|
1968 | 2007 | return self._print(FunctionCall(Pyccel__del, Pyccel_del_args))
|
1969 | 2008 |
|
1970 |
| - if var.is_alias or isinstance(class_type, HomogeneousListType): |
| 2009 | + if var.is_alias or isinstance(class_type, (HomogeneousListType, HomogeneousSetType)): |
1971 | 2010 | return ''
|
1972 | 2011 | elif isinstance(class_type, (NumpyNDArrayType, HomogeneousTupleType, StringType)):
|
1973 | 2012 | var_code = self._print(var)
|
@@ -2011,6 +2050,9 @@ def _print_PythonNativeBool(self, expr):
|
2011 | 2050 | def _print_HomogeneousListType(self, expr):
|
2012 | 2051 | return 'Vector_'+self._print(expr.element_type)
|
2013 | 2052 |
|
| 2053 | + def _print_HomogeneousSetType(self, expr): |
| 2054 | + return 'Set_'+self._print(expr.element_type) |
| 2055 | + |
2014 | 2056 | def _print_IteratorType(self, expr):
|
2015 | 2057 | iterable_type = self._print(expr.iterable_type)
|
2016 | 2058 | return f"{iterable_type}_Iterator"
|
|
0 commit comments