Skip to content

Commit 286e1fc

Browse files
authored
Fix multiple calls to _build_gFTL_extension_module (pyccel#2114)
Fix error when there are multiple calls to `_build_gFTL_extension_module`. Fixes pyccel#2103 . This error is due to illogical usage of `matching_expr_type`. The error was introduced since the last version so this is not added to the CHANGELOG.
1 parent b5abf6f commit 286e1fc

File tree

3 files changed

+40
-8
lines changed

3 files changed

+40
-8
lines changed

pyccel/codegen/printing/fcode.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -677,26 +677,31 @@ def _build_gFTL_extension_module(self, expr_type):
677677
Import
678678
The import which allows the new type to be accessed.
679679
"""
680-
# Get the type used in the dict for compatible types (e.g. float vs float64)
681-
matching_expr_type = next((t for t in self._generated_gFTL_types if expr_type == t), None)
682-
matching_expr_extensions = next((t for t in self._generated_gFTL_extensions if expr_type == t), None)
680+
# Get the module describing the type
681+
matching_expr_type = next((m for t,m in self._generated_gFTL_types.items() if expr_type == t), None)
682+
# Get the module describing the extension
683+
matching_expr_extensions = next((m for t,m in self._generated_gFTL_extensions.items() if expr_type == t), None)
683684
typename = self._print(expr_type)
684685
mod_name = f'{typename}_extensions_mod'
685686
if matching_expr_extensions:
686-
module = self._generated_gFTL_extensions[matching_expr_extensions]
687+
module = matching_expr_extensions
687688
else:
689+
imports_and_macros = []
690+
688691
if matching_expr_type is None:
689692
matching_expr_type = self._build_gFTL_module(expr_type)
690693
self.add_import(matching_expr_type)
691-
692-
type_module = matching_expr_type.source_module
694+
imports_and_macros.append(matching_expr_type)
695+
type_module = matching_expr_type.source_module
696+
else:
697+
type_module = matching_expr_type
698+
imports_and_macros.append(Import(f'gFTL_extensions/{mod_name}', type_module))
693699

694700
if isinstance(expr_type, HomogeneousSetType):
695701
set_filename = LiteralString('set/template.inc')
696-
imports_and_macros = [Import(LiteralString('Set_extensions.inc'), Module('_', (), ())) \
702+
imports_and_macros += [Import(LiteralString('Set_extensions.inc'), Module('_', (), ())) \
697703
if getattr(i, 'source', None) == set_filename else i \
698704
for i in type_module.imports]
699-
imports_and_macros.insert(0, matching_expr_type)
700705
self.add_import(Import('gFTL_functions/Set_extensions', Module('_', (), ()), ignore_at_print = True))
701706
else:
702707
raise NotImplementedError(f"Unkown gFTL import for type {expr_type}")

tests/epyccel/modules/Module_11.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# pylint: disable=missing-function-docstring, missing-module-docstring
2+
3+
def update_multiple():
4+
a = {1, 2, 3}
5+
a.update({4, 5})
6+
return len(a)
7+
8+
def set_union():
9+
a = {1,2,3,4,5}
10+
b = {4,5,6}
11+
c = a.union(b)
12+
return len(c)

tests/epyccel/test_epyccel_modules.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,18 @@ def test_module_type_alias_expression(language):
243243
assert np.isclose( max_pyt, max_pyc, rtol=1e-14, atol=1e-14 )
244244
assert np.allclose( x, x_pyc, rtol=1e-14, atol=1e-14 )
245245
assert np.allclose( y, y_pyc, rtol=1e-14, atol=1e-14 )
246+
247+
def test_module_11(language):
248+
import modules.Module_11 as mod
249+
250+
modnew = epyccel(mod, language=language)
251+
252+
len_pyt = mod.update_multiple()
253+
len_pyc = modnew.update_multiple()
254+
255+
assert len_pyt == len_pyc
256+
257+
len_pyt = mod.set_union()
258+
len_pyc = modnew.set_union()
259+
260+
assert len_pyt == len_pyc

0 commit comments

Comments
 (0)