Skip to content

Commit 0341065

Browse files
Fix broken chaining of files using STC imports (pyccel#2086)
This PR aims to resolve issue pyccel#2084 by systematically including extension headers whenever STC headers are included to avoid double includes. --------- Co-authored-by: Emily Bourne <[email protected]>
1 parent 9b0545f commit 0341065

File tree

3 files changed

+32
-47
lines changed

3 files changed

+32
-47
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ All notable changes to this project will be documented in this file.
8484
- #2039 : Ensure any expressions in the iterable of a for loop are calculated before the loop.
8585
- #2013 : Stop limiting the length of strings to 128 characters.
8686
- #2078 : Fix translation of classes containing comments.
87+
- #2041 : Include all type extension methods by default.
8788

8889
### Changed
8990

pyccel/codegen/printing/ccode.py

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,13 @@
249249
'stdbool',
250250
'assert']}
251251

252-
import_header_guard_prefix = {'Set_extensions' : '_TOOLS_SET',
253-
'List_extensions' : '_TOOLS_LIST',
254-
'Common_extensions' : '_TOOLS_COMMON'}
252+
import_header_guard_prefix = {'stc/hset' : '_TOOLS_SET',
253+
'stc/vec' : '_TOOLS_LIST',
254+
'stc/common' : '_TOOLS_COMMON'}
255255

256-
257-
stc_header_mapping = {'List_extensions': 'stc/vec',
258-
'Set_extensions': 'stc/hset',
259-
'Common_extensions': 'stc/common'}
256+
stc_extension_mapping = {'stc/vec': 'List_extensions',
257+
'stc/hset' : 'Set_extensions',
258+
'stc/common' : 'Common_extensions'}
260259

261260
class CCodePrinter(CodePrinter):
262261
"""
@@ -682,32 +681,6 @@ def init_stc_container(self, expr, assignment_var):
682681
init = f'{container_name} = c_init({dtype}, {keyraw});\n'
683682
return init
684683

685-
def invalidate_stc_headers(self, imports):
686-
"""
687-
Invalidate STC headers when STC extension headers are present.
688-
689-
This function iterates over the list of imports and removes any targets
690-
from STC headers if the target is present in their corresponding
691-
STC extension headers.
692-
The STC extension headers take care of including the standard
693-
headers.
694-
695-
Parameters
696-
----------
697-
imports : list of Import
698-
The list of Import objects representing the header files to include.
699-
700-
Returns
701-
-------
702-
None
703-
The function modifies the `imports` list in-place.
704-
"""
705-
for imp in imports:
706-
if imp.source in stc_header_mapping:
707-
for imp2 in imports:
708-
if imp2.source == stc_header_mapping[imp.source]:
709-
imp2.remove_target(imp.target)
710-
711684
def rename_imported_methods(self, expr):
712685
"""
713686
Rename class methods from user-defined imports.
@@ -773,7 +746,9 @@ def _print_PythonMinMax(self, expr):
773746
elif len(arg) > 2 and isinstance(arg.dtype.primitive_type, (PrimitiveFloatingPointType, PrimitiveIntegerType)):
774747
key = self.get_declare_type(arg[0])
775748
self.add_import(Import('stc/common', AsName(VariableTypeAnnotation(arg.dtype), key)))
776-
self.add_import(Import('Common_extensions', AsName(VariableTypeAnnotation(arg.dtype), key)))
749+
self.add_import(Import('Common_extensions',
750+
AsName(VariableTypeAnnotation(arg.dtype), key),
751+
ignore_at_print=True))
777752
return f'{key}_{expr.name}({len(arg)}, {", ".join(self._print(a) for a in arg)})'
778753
else:
779754
return errors.report(f"{expr.name} in C does not support arguments of type {arg.dtype}", symbol=expr,
@@ -875,7 +850,6 @@ def _print_ModuleHeader(self, expr):
875850

876851
# Print imports last to be sure that all additional_imports have been collected
877852
imports = [*expr.module.imports, *self._additional_imports.values()]
878-
self.invalidate_stc_headers(imports)
879853
imports = ''.join(self._print(i) for i in imports)
880854

881855
self._in_header = False
@@ -1052,7 +1026,7 @@ def _print_Import(self, expr):
10521026
source = source.name[-1].python_value
10531027
else:
10541028
source = self._print(source)
1055-
if source == 'Common_extensions':
1029+
if source == 'stc/common':
10561030
code = ''
10571031
for t in expr.target:
10581032
element_decl = f'#define i_key {t.local_alias}\n'
@@ -1061,11 +1035,11 @@ def _print_Import(self, expr):
10611035
code += ''.join((f'#ifndef {header_guard}\n',
10621036
f'#define {header_guard}\n',
10631037
element_decl,
1064-
f'#include <{stc_header_mapping[source]}.h>\n',
10651038
f'#include <{source}.h>\n',
1039+
f'#include <{stc_extension_mapping[source]}.h>\n',
10661040
f'#endif // {header_guard}\n\n'))
10671041
return code
1068-
elif source.startswith('stc/') or source in import_header_guard_prefix:
1042+
elif source.startswith('stc/'):
10691043
code = ''
10701044
for t in expr.target:
10711045
class_type = t.object.class_type
@@ -1087,11 +1061,12 @@ def _print_Import(self, expr):
10871061
f'#define {header_guard}\n',
10881062
f'#define i_type {container_type}\n',
10891063
element_decl,
1090-
'#define i_more\n' if source in import_header_guard_prefix else '',
1091-
f'#include <{stc_header_mapping[source]}.h>\n' if source in import_header_guard_prefix else '',
1092-
f'#include <{source}.h>\n',
1064+
'#define i_more\n' if source in stc_extension_mapping else '',
1065+
f'#include <{source}.h>\n',
1066+
f'#include <{stc_extension_mapping[source]}.h>\n' if source in stc_extension_mapping else '',
10931067
f'#endif // {header_guard}\n\n'))
10941068
return code
1069+
10951070
# Get with a default value is not used here as it is
10961071
# slower and on most occasions the import will not be in the
10971072
# dictionary
@@ -1316,6 +1291,9 @@ def get_c_type(self, dtype):
13161291
element_type = self.get_c_type(dtype.element_type).replace(' ', '_')
13171292
i_type = f'{container_type}_{element_type}'
13181293
self.add_import(Import(f'stc/{container_type}', AsName(VariableTypeAnnotation(dtype), i_type)))
1294+
self.add_import(Import(f'{stc_extension_mapping["stc/" + container_type]}',
1295+
AsName(VariableTypeAnnotation(dtype), i_type),
1296+
ignore_at_print=True))
13191297
return i_type
13201298
elif isinstance(dtype, DictType):
13211299
container_type = 'hmap'
@@ -2664,7 +2642,6 @@ def _print_Program(self, expr):
26642642
decs = ''.join(self._print(Declare(v)) for v in variables)
26652643

26662644
imports = [*expr.imports, *self._additional_imports.values()]
2667-
self.invalidate_stc_headers(imports)
26682645
imports = ''.join(self._print(i) for i in imports)
26692646

26702647
self.exit_scope()
@@ -2706,7 +2683,10 @@ def _print_ListPop(self, expr):
27062683
c_type = self.get_c_type(class_type)
27072684
list_obj = self._print(ObjectAddress(expr.list_obj))
27082685
if expr.index_element:
2709-
self.add_import(Import('List_extensions', AsName(VariableTypeAnnotation(class_type), c_type)))
2686+
self.add_import(Import('stc/vec', AsName(VariableTypeAnnotation(class_type), c_type)))
2687+
self.add_import(Import('List_extensions',
2688+
AsName(VariableTypeAnnotation(class_type), c_type),
2689+
ignore_at_print=True))
27102690
if is_literal_integer(expr.index_element) and int(expr.index_element) < 0:
27112691
idx_code = self._print(PyccelAdd(PythonLen(expr.list_obj), expr.index_element, simplify=True))
27122692
else:
@@ -2720,7 +2700,10 @@ def _print_ListPop(self, expr):
27202700
def _print_SetPop(self, expr):
27212701
dtype = expr.set_variable.class_type
27222702
var_type = self.get_c_type(dtype)
2723-
self.add_import(Import('Set_extensions', AsName(VariableTypeAnnotation(dtype), var_type)))
2703+
self.add_import(Import('stc/hset', AsName(VariableTypeAnnotation(dtype), var_type)))
2704+
self.add_import(Import('Set_extensions',
2705+
AsName(VariableTypeAnnotation(dtype), var_type),
2706+
ignore_at_print=True))
27242707
set_var = self._print(ObjectAddress(expr.set_variable))
27252708
return f'{var_type}_pop({set_var})'
27262709

@@ -2747,7 +2730,10 @@ def _print_SetUnion(self, expr):
27472730
severity='error', symbol=expr)
27482731
class_type = expr.set_variable.class_type
27492732
var_type = self.get_c_type(class_type)
2750-
self.add_import(Import('Set_extensions', AsName(VariableTypeAnnotation(class_type), var_type)))
2733+
self.add_import(Import('stc/hset', AsName(VariableTypeAnnotation(class_type), var_type)))
2734+
self.add_import(Import('Set_extensions',
2735+
AsName(VariableTypeAnnotation(class_type), var_type),
2736+
ignore_at_print=True))
27512737
set_var = self._print(ObjectAddress(expr.set_variable))
27522738
args = ', '.join([str(len(expr.args)), *(self._print(ObjectAddress(a)) for a in expr.args)])
27532739
return f'{var_type}_union({set_var}, {args})'

pyccel/codegen/printing/cwrappercode.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def _print_ModuleHeader(self, expr):
267267
imports = [*module_imports, *mod.imports]
268268
for i in imports:
269269
self.add_import(i)
270-
self.invalidate_stc_headers(imports)
271270
imports = ''.join(self._print(i) for i in imports)
272271

273272
function_signatures = ''.join(self.function_signature(f, print_arg_names = False) + ';\n' for f in mod.external_funcs)
@@ -374,7 +373,6 @@ def _print_PyModule(self, expr):
374373

375374
pymod_name = f'{expr.name}_wrapper'
376375
imports = [Import(pymod_name, Module(pymod_name,(),())), *self._additional_imports.values()]
377-
self.invalidate_stc_headers(imports)
378376
imports = ''.join(self._print(i) for i in imports)
379377

380378
self.exit_scope()

0 commit comments

Comments
 (0)