|
46 | 46 | from pyccel.ast.numpytypes import NumpyFloat32Type, NumpyFloat64Type, NumpyComplex64Type, NumpyComplex128Type
|
47 | 47 | from pyccel.ast.numpytypes import NumpyNDArrayType, numpy_precision_map
|
48 | 48 |
|
| 49 | +from pyccel.ast.type_annotations import VariableTypeAnnotation |
| 50 | + |
49 | 51 | from pyccel.ast.utilities import expand_to_loops
|
50 | 52 |
|
51 | 53 | from pyccel.ast.variable import IndexedElement
|
|
239 | 241 | 'assert',
|
240 | 242 | 'numpy_c']}
|
241 | 243 |
|
| 244 | +import_header_guard_prefix = {'Set_extensions' : '_TOOLS_SET'} |
| 245 | + |
242 | 246 | class CCodePrinter(CodePrinter):
|
243 | 247 | """
|
244 | 248 | A printer for printing code in C.
|
@@ -312,8 +316,22 @@ def get_additional_imports(self):
|
312 | 316 | return self._additional_imports.keys()
|
313 | 317 |
|
314 | 318 | def add_import(self, import_obj):
|
| 319 | + """ |
| 320 | + Add a new import to the current context. |
| 321 | +
|
| 322 | + Add a new import to the current context. This allows the import to be recognised |
| 323 | + at the compiling/linking stage. If the source of the import is not new then any |
| 324 | + new targets are added to the Import object. |
| 325 | +
|
| 326 | + Parameters |
| 327 | + ---------- |
| 328 | + import_obj : Import |
| 329 | + The AST node describing the import. |
| 330 | + """ |
315 | 331 | if import_obj.source not in self._additional_imports:
|
316 | 332 | self._additional_imports[import_obj.source] = import_obj
|
| 333 | + elif import_obj.target: |
| 334 | + self._additional_imports[import_obj.source].define_target(import_obj.target) |
317 | 335 |
|
318 | 336 | def _get_statement(self, codestring):
|
319 | 337 | return "%s;\n" % codestring
|
@@ -996,24 +1014,21 @@ def _print_Import(self, expr):
|
996 | 1014 | source = source.name[-1]
|
997 | 1015 | else:
|
998 | 1016 | source = self._print(source)
|
999 |
| - if source.startswith('stc/'): |
1000 |
| - stc_name, container_type, container_key = source.split("/") |
1001 |
| - container = container_type.split("_") |
1002 |
| - return '\n'.join((f'#ifndef _{container_type.upper()}', |
1003 |
| - f'#define _{container_type.upper()}', |
1004 |
| - f'#define i_type {container_type}', |
1005 |
| - f'#define i_key {container_key}', |
1006 |
| - f'#include "{stc_name + "/" + container[0]}.h"', |
1007 |
| - '#endif\n')) |
1008 |
| - elif source.startswith('Set_pop'): |
1009 |
| - _ , i_type, i_key = source.split('/') |
1010 |
| - self.add_import(Import('STC_Extensions', Module('STC_Extensions', (), ()))) |
1011 |
| - return '\n'.join(( |
1012 |
| - f'#ifndef TOOLS_SET_{str(i_key).upper()}\n' |
1013 |
| - f'#define TOOLS_SET_{str(i_key).upper()}\n' |
1014 |
| - f'#define i_type {i_type}', |
1015 |
| - f'#define i_key {i_key}\n', |
1016 |
| - '#include "Set_extensions.h"\n#endif\n')) |
| 1017 | + if source.startswith('stc/') or source in import_header_guard_prefix: |
| 1018 | + code = '' |
| 1019 | + for t in expr.target: |
| 1020 | + dtype = t.object.class_type |
| 1021 | + container_type = t.target |
| 1022 | + container_key = self.get_c_type(dtype.element_type) |
| 1023 | + header_guard_prefix = import_header_guard_prefix.get(source, '') |
| 1024 | + header_guard = f'{header_guard_prefix}_{container_type.upper()}' |
| 1025 | + code += (f'#ifndef {header_guard}\n' |
| 1026 | + f'#define {header_guard}\n' |
| 1027 | + f'#define i_type {container_type}\n' |
| 1028 | + f'#define i_key {container_key}\n' |
| 1029 | + f'#include <{source}.h>\n' |
| 1030 | + f'#endif // {header_guard}\n\n') |
| 1031 | + return code |
1017 | 1032 | # Get with a default value is not used here as it is
|
1018 | 1033 | # slower and on most occasions the import will not be in the
|
1019 | 1034 | # dictionary
|
@@ -1231,10 +1246,10 @@ def get_c_type(self, dtype):
|
1231 | 1246 |
|
1232 | 1247 | key = (primitive_type, dtype.precision)
|
1233 | 1248 | elif isinstance(dtype, (HomogeneousSetType, HomogeneousListType)):
|
1234 |
| - container_type = 'hset_' if dtype.name == 'set' else 'vec_' |
1235 |
| - element_type = self.get_c_type(dtype.element_type) |
1236 |
| - i_type = container_type + element_type.replace(' ', '_') |
1237 |
| - self.add_import(Import(f'stc/{i_type}/{element_type}', Module(f'stc/{i_type}', (), ()))) |
| 1249 | + container_type = 'hset' if dtype.name == 'set' else 'vec' |
| 1250 | + element_type = self.get_c_type(dtype.element_type).replace(' ', '_') |
| 1251 | + i_type = f'{container_type}_{element_type}' |
| 1252 | + self.add_import(Import(f'stc/{container_type}', AsName(VariableTypeAnnotation(dtype), i_type))) |
1238 | 1253 | return i_type
|
1239 | 1254 | else:
|
1240 | 1255 | key = dtype
|
@@ -2193,9 +2208,9 @@ def _print_Assign(self, expr):
|
2193 | 2208 | return prefix_code+'{} = {};\n'.format(lhs, rhs)
|
2194 | 2209 |
|
2195 | 2210 | def _print_SetPop(self, expr):
|
2196 |
| - var_type = self.get_declare_type(expr.set_variable) |
2197 |
| - element_type = self.get_c_type(expr.set_variable.class_type.element_type) |
2198 |
| - self.add_import(Import(f'Set_pop_macro/{var_type}/{element_type}', Module(f'Set_pop_macro/{var_type}/{element_type}', (), ()))) |
| 2211 | + dtype = expr.set_variable.class_type |
| 2212 | + var_type = self.get_c_type(dtype) |
| 2213 | + self.add_import(Import('Set_extensions', AsName(VariableTypeAnnotation(dtype), var_type))) |
2199 | 2214 | set_var = self._print(ObjectAddress(expr.set_variable))
|
2200 | 2215 | return f'{var_type}_pop({set_var})'
|
2201 | 2216 |
|
|
0 commit comments