Skip to content

Commit 53ed828

Browse files
authored
Add Fortran support for set method pop (pyccel#2014)
Add Fortran support for set method `pop`. Fixes pyccel#2010 **Commit Summary** - Add `MacroUndef` class to avoid macros polluting namespaces - Simplify code using `_get_comparison_operator` to avoid duplication - Add `_build_gFTL_extension_module` to create extensions for gFTL types - Add support for `set.pop` - Fix Python printing so expressions that can be used in expressions don't end in `\n` - Activate `set.pop` tests - Activate `test_init_` functions which don't require specific language support
1 parent 08efed8 commit 53ed828

File tree

9 files changed

+204
-58
lines changed

9 files changed

+204
-58
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ All notable changes to this project will be documented in this file.
3636
- #1689 : Add C and Fortran support for list method `append()`.
3737
- #1876 : Add C support for indexing lists.
3838
- #1690 : Add C support for list method `pop()`.
39-
- #1877 : Add C Support for set method `pop()`.
39+
- #1877 : Add C and Fortran Support for set method `pop()`.
4040
- #1917 : Add C and Fortran support for set method `add()`.
4141
- #1918 : Add C and Fortran support for set method `clear()`.
4242
- #1936 : Add missing C output for inline decorator example in documentation

docs/builtin-functions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ Python contains a limited number of builtin functions defined [here](https://doc
107107
| `isdisjoint` | No |
108108
| `issubset` | No |
109109
| `issuperset` | No |
110-
| `pop` | C and Python |
110+
| **`pop`** | **Yes** |
111111
| `remove` | Python-only |
112112
| `symmetric_difference` | No |
113113
| `symmetric_difference_update` | No |

pyccel/ast/low_level_tools.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
__all__ = ('IteratorType',
1414
'PairType',
15-
'MacroDefinition')
15+
'MacroDefinition',
16+
'MacroUndef')
1617

1718
#------------------------------------------------------------------------------
1819
class IteratorType(PyccelType, metaclass=ArgumentSingleton):
@@ -127,3 +128,32 @@ def object(self):
127128
"""
128129
return self._obj
129130

131+
#------------------------------------------------------------------------------
132+
class MacroUndef(PyccelAstNode):
133+
"""
134+
A class for undefining a macro in a file.
135+
136+
A class for undefining a macro in a file.
137+
138+
Parameters
139+
----------
140+
macro_name : str
141+
The name of the macro.
142+
"""
143+
_attribute_nodes = ()
144+
__slots__ = ('_macro_name',)
145+
146+
def __init__(self, macro_name):
147+
assert isinstance(macro_name, str)
148+
self._macro_name = macro_name
149+
super().__init__()
150+
151+
@property
152+
def macro_name(self):
153+
"""
154+
The name of the macro being undefined.
155+
156+
The name of the macro being undefined.
157+
"""
158+
return self._macro_name
159+

pyccel/codegen/printing/fcode.py

Lines changed: 120 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from pyccel.ast.literals import Nil
5353

5454
from pyccel.ast.low_level_tools import MacroDefinition, IteratorType, PairType
55+
from pyccel.ast.low_level_tools import MacroUndef
5556

5657
from pyccel.ast.mathext import math_constants
5758

@@ -265,6 +266,7 @@ def __init__(self, filename, prefix_module = None):
265266

266267
self.prefix_module = prefix_module
267268

269+
self._generated_gFTL_types = {}
268270
self._generated_gFTL_extensions = {}
269271

270272
def print_constant_imports(self):
@@ -522,6 +524,41 @@ def _calculate_class_names(self, expr):
522524
scope.rename_function(f, suggested_name)
523525
f.cls_name = scope.get_new_name(f'{name}_{f.name}')
524526

527+
def _get_comparison_operator(self, element_type, imports_and_macros):
528+
"""
529+
Get an AST node describing a comparison operator between two objects of the same type.
530+
531+
Get an AST node describing a comparison operator between two objects of the same type.
532+
This is useful when defining gFTL modules.
533+
534+
Parameters
535+
----------
536+
element_type : PyccelType
537+
The data type to be compared.
538+
539+
imports_and_macros : list
540+
A list of imports or macros for the gFTL module.
541+
542+
Returns
543+
-------
544+
PyccelAstNode
545+
An AST node describing a comparison operator.
546+
"""
547+
tmpVar_x = Variable(element_type, 'x')
548+
tmpVar_y = Variable(element_type, 'y')
549+
if isinstance(element_type.primitive_type, PrimitiveComplexType):
550+
complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),()))
551+
self.add_import(complex_tool_import)
552+
imports_and_macros.append(complex_tool_import)
553+
compare_func = FunctionDef('complex_comparison',
554+
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
555+
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
556+
lt_def = FunctionCall(compare_func, [tmpVar_x, tmpVar_y])
557+
else:
558+
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))
559+
560+
return lt_def
561+
525562
def _build_gFTL_module(self, expr_type):
526563
"""
527564
Build the gFTL module to create container types.
@@ -541,13 +578,12 @@ def _build_gFTL_module(self, expr_type):
541578
The import which allows the new type to be accessed.
542579
"""
543580
# Get the type used in the dict for compatible types (e.g. float vs float64)
544-
matching_expr_type = next((t for t in self._generated_gFTL_extensions if expr_type == t), None)
581+
matching_expr_type = next((t for t in self._generated_gFTL_types if expr_type == t), None)
545582
if matching_expr_type:
546-
module = self._generated_gFTL_extensions[matching_expr_type]
583+
module = self._generated_gFTL_types[matching_expr_type]
547584
mod_name = module.name
548585
else:
549586
if isinstance(expr_type, HomogeneousListType):
550-
include = Import(LiteralString('vector/template.inc'), Module('_', (), ()))
551587
element_type = expr_type.element_type
552588
if isinstance(element_type, FixedSizeNumericType):
553589
imports_and_macros = [MacroDefinition('T', element_type.primitive_type),
@@ -558,50 +594,32 @@ def _build_gFTL_module(self, expr_type):
558594
imports_and_macros.append(MacroDefinition('T_rank', element_type.rank))
559595
elif not isinstance(element_type, FixedSizeNumericType):
560596
raise NotImplementedError("Support for lists of types defined in other modules is not yet implemented")
561-
imports_and_macros.append(MacroDefinition('Vector', expr_type))
562-
imports_and_macros.append(MacroDefinition('VectorIterator', IteratorType(expr_type)))
597+
imports_and_macros.extend([MacroDefinition('Vector', expr_type),
598+
MacroDefinition('VectorIterator', IteratorType(expr_type)),
599+
Import(LiteralString('vector/template.inc'), Module('_', (), ())),
600+
MacroUndef('Vector'),
601+
MacroUndef('VectorIterator'),])
563602
elif isinstance(expr_type, HomogeneousSetType):
564-
include = Import(LiteralString('set/template.inc'), Module('_', (), ()))
565603
element_type = expr_type.element_type
566604
imports_and_macros = []
567605
if isinstance(element_type, FixedSizeNumericType):
568-
tmpVar_x = Variable(element_type, 'x')
569-
tmpVar_y = Variable(element_type, 'y')
570-
if isinstance(element_type.primitive_type, PrimitiveComplexType):
571-
complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),()))
572-
self.add_import(complex_tool_import)
573-
imports_and_macros.append(complex_tool_import)
574-
compare_func = FunctionDef('complex_comparison',
575-
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
576-
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
577-
lt_def = FunctionCall(compare_func, [tmpVar_x, tmpVar_y])
578-
else:
579-
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))
606+
lt_def = self._get_comparison_operator(element_type, imports_and_macros)
580607
imports_and_macros.extend([MacroDefinition('T', element_type.primitive_type),
581608
MacroDefinition('T_KINDLEN(context)', KindSpecification(element_type)),
582609
MacroDefinition('T_LT(x,y)', lt_def)])
583610
else:
584611
raise NotImplementedError("Support for sets of types which define their own < operator is not yet implemented")
585-
imports_and_macros.append(MacroDefinition('Set', expr_type))
586-
imports_and_macros.append(MacroDefinition('SetIterator', IteratorType(expr_type)))
612+
imports_and_macros.extend([MacroDefinition('Set', expr_type),
613+
MacroDefinition('SetIterator', IteratorType(expr_type)),
614+
Import(LiteralString('set/template.inc'), Module('_', (), ())),
615+
MacroUndef('Set'),
616+
MacroUndef('SetIterator')])
587617
elif isinstance(expr_type, DictType):
588-
include = Import(LiteralString('map/template.inc'), Module('_', (), ()))
589618
key_type = expr_type.key_type
590619
value_type = expr_type.value_type
591620
imports_and_macros = []
592621
if isinstance(key_type, FixedSizeNumericType):
593-
tmpVar_x = Variable(key_type, 'x')
594-
tmpVar_y = Variable(key_type, 'y')
595-
if isinstance(key_type.primitive_type, PrimitiveComplexType):
596-
complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),()))
597-
self.add_import(complex_tool_import)
598-
imports_and_macros.append(complex_tool_import)
599-
compare_func = FunctionDef('complex_comparison',
600-
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
601-
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
602-
lt_def = FunctionCall(compare_func, [tmpVar_x, tmpVar_y])
603-
else:
604-
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))
622+
lt_def = self._get_comparison_operator(key_type, imports_and_macros)
605623
imports_and_macros.extend([MacroDefinition('Key', key_type.primitive_type),
606624
MacroDefinition('Key_KINDLEN(context)', KindSpecification(key_type)),
607625
MacroDefinition('Key_LT(x,y)', lt_def)])
@@ -612,15 +630,68 @@ def _build_gFTL_module(self, expr_type):
612630
MacroDefinition('T_KINDLEN(context)', KindSpecification(value_type))])
613631
else:
614632
raise NotImplementedError(f"Support for dictionary values of type {value_type} not yet implemented")
615-
imports_and_macros.append(MacroDefinition('Pair', PairType(key_type, value_type)))
616-
imports_and_macros.append(MacroDefinition('Map', expr_type))
617-
imports_and_macros.append(MacroDefinition('MapIterator', IteratorType(expr_type)))
633+
imports_and_macros.extend([MacroDefinition('Pair', PairType(key_type, value_type)),
634+
MacroDefinition('Map', expr_type),
635+
MacroDefinition('MapIterator', IteratorType(expr_type)),
636+
Import(LiteralString('map/template.inc'), Module('_', (), ())),
637+
MacroUndef('Pair'),
638+
MacroUndef('Map'),
639+
MacroUndef('MapIterator')])
618640
else:
619641
raise NotImplementedError(f"Unkown gFTL import for type {expr_type}")
620642

621643
typename = self._print(expr_type)
622644
mod_name = f'{typename}_mod'
623-
module = Module(mod_name, (), (), scope = Scope(), imports = [*imports_and_macros, include],
645+
module = Module(mod_name, (), (), scope = Scope(), imports = imports_and_macros,
646+
is_external = True)
647+
648+
self._generated_gFTL_types[expr_type] = module
649+
650+
return Import(f'gFTL_extensions/{mod_name}', module)
651+
652+
def _build_gFTL_extension_module(self, expr_type):
653+
"""
654+
Build the gFTL module to create container extension functions.
655+
656+
Create a module which will import the gFTL include files
657+
in order to create container types (e.g lists, sets, etc).
658+
The name of the module is derived from the name of the type.
659+
660+
Parameters
661+
----------
662+
expr_type : DataType
663+
The data type for which extensions are required.
664+
665+
Returns
666+
-------
667+
Import
668+
The import which allows the new type to be accessed.
669+
"""
670+
# Get the type used in the dict for compatible types (e.g. float vs float64)
671+
matching_expr_type = next((t for t in self._generated_gFTL_types if expr_type == t), None)
672+
matching_expr_extensions = next((t for t in self._generated_gFTL_extensions if expr_type == t), None)
673+
typename = self._print(expr_type)
674+
mod_name = f'{typename}_extensions_mod'
675+
if matching_expr_extensions:
676+
module = self._generated_gFTL_extensions[matching_expr_extensions]
677+
else:
678+
if matching_expr_type is None:
679+
matching_expr_type = self._build_gFTL_module(expr_type)
680+
self.add_import(matching_expr_type)
681+
682+
type_module = matching_expr_type.source_module
683+
684+
if isinstance(expr_type, HomogeneousSetType):
685+
set_filename = LiteralString('set/template.inc')
686+
imports_and_macros = [Import(LiteralString('Set_extensions.inc'), Module('_', (), ())) \
687+
if getattr(i, 'source', None) == set_filename else i \
688+
for i in type_module.imports]
689+
imports_and_macros.insert(0, matching_expr_type)
690+
self.add_import(Import('gFTL_functions/Set_extensions', Module('_', (), ()), ignore_at_print = True))
691+
else:
692+
raise NotImplementedError(f"Unkown gFTL import for type {expr_type}")
693+
694+
module = Module(mod_name, (), (), scope = Scope(), imports = imports_and_macros,
624695
is_external = True)
625696

626697
self._generated_gFTL_extensions[expr_type] = module
@@ -1253,6 +1324,14 @@ def _print_SetClear(self, expr):
12531324
var = self._print(expr.set_variable)
12541325
return f'call {var} % clear()\n'
12551326

1327+
def _print_SetPop(self, expr):
1328+
var = expr.set_variable
1329+
expr_type = var.class_type
1330+
var_code = self._print(expr.set_variable)
1331+
type_name = self._print(expr_type)
1332+
self.add_import(self._build_gFTL_extension_module(expr_type))
1333+
return f'{type_name}_pop({var_code})\n'
1334+
12561335
#========================== Numpy Elements ===============================#
12571336

12581337
def _print_NumpySum(self, expr):
@@ -3596,5 +3675,9 @@ def _print_MacroDefinition(self, expr):
35963675
obj = self._print(expr.object)
35973676
return f'#define {name} {obj}\n'
35983677

3678+
def _print_MacroUndef(self, expr):
3679+
name = expr.macro_name
3680+
return f'#undef {name}\n'
3681+
35993682
def _print_KindSpecification(self, expr):
36003683
return f'(kind = {self.print_kind(expr.type_specifier)})'

pyccel/codegen/printing/pycode.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pyccel.ast.builtins import PythonComplex, DtypePrecisionToCastFunction
1212
from pyccel.ast.core import CodeBlock, Import, Assign, FunctionCall, For, AsName, FunctionAddress
1313
from pyccel.ast.core import IfSection, FunctionDef, Module, PyccelFunctionDef
14-
from pyccel.ast.datatypes import HomogeneousTupleType, HomogeneousListType
14+
from pyccel.ast.datatypes import HomogeneousTupleType, VoidType
1515
from pyccel.ast.functionalexpr import FunctionalFor
1616
from pyccel.ast.literals import LiteralTrue, LiteralString, LiteralInteger
1717
from pyccel.ast.numpyext import numpy_target_swap
@@ -841,9 +841,9 @@ def _print_DictPop(self, expr):
841841
key = self._print(expr.key)
842842
if expr.default_value:
843843
val = self._print(expr.default_value)
844-
return f"{dict_obj}.pop({key}, {val})\n"
844+
return f"{dict_obj}.pop({key}, {val})"
845845
else:
846-
return f"{dict_obj}.pop({key})\n"
846+
return f"{dict_obj}.pop({key})"
847847

848848
def _print_DictGet(self, expr):
849849
dict_obj = self._print(expr.dict_obj)
@@ -871,7 +871,11 @@ def _print_SetMethod(self, expr):
871871
name = expr.name
872872
args = "" if len(expr.args) == 0 or expr.args[-1] is None \
873873
else ', '.join(self._print(a) for a in expr.args)
874-
return f"{set_var}.{name}({args})\n"
874+
code = f"{set_var}.{name}({args})"
875+
if expr.class_type is VoidType():
876+
return f'{code}\n'
877+
else:
878+
return code
875879

876880
def _print_Nil(self, expr):
877881
return 'None'

pyccel/codegen/utilities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"numpy_c" : ("numpy", CompileObj("numpy_c.c",folder="numpy")),
4747
"Set_extensions" : ("STC_Extensions", CompileObj("Set_Extensions.h", folder="STC_Extensions", has_target_file = False)),
4848
"List_extensions" : ("STC_Extensions", CompileObj("List_Extensions.h", folder="STC_Extensions", has_target_file = False)),
49+
"gFTL_functions/Set_extensions" : ("gFTL_functions", CompileObj("Set_Extensions.inc", folder="gFTL_functions", has_target_file = False)),
4950
}
5051
internal_libs["cwrapper_ndarrays"] = ("cwrapper_ndarrays", CompileObj("cwrapper_ndarrays.c",folder="cwrapper_ndarrays",
5152
accelerators = ('python',),
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include <set/header.inc>
2+
3+
contains
4+
5+
#define __IDENTITY(x) x
6+
#define __guard __set_guard
7+
#include "parameters/T/copy_set_T_to_internal_T.inc"
8+
#include "parameters/T/define_derived_macros.inc"
9+
10+
function __IDENTITY(Set)_pop(my_set) result(result)
11+
class(Set), intent(inout) :: my_set
12+
__T_declare_dummy__ :: result
13+
14+
type(SetIterator) :: iter1
15+
type(SetIterator) :: iter2
16+
17+
iter1 = my_set%begin()
18+
19+
result = iter1%of()
20+
21+
iter2 = my_set%erase(iter1)
22+
23+
end function __IDENTITY(Set)_pop
24+
25+
#include <set/tail.inc>

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ include = [
5858
"pyccel/stdlib/**/*.h",
5959
"pyccel/stdlib/**/*.c",
6060
"pyccel/stdlib/**/*.f90",
61+
"pyccel/stdlib/**/*.inc",
6162
"pyccel/extensions/STC/include"
6263
]
6364
exclude = [

0 commit comments

Comments
 (0)