Skip to content

Commit 3b3c327

Browse files
authored
Unify import strategy (pyccel#1975)
Unify the strategy for handling additional imports in the printing stage for different languages. This allows `get_additional_imports` and `add_import` to be moved to the superclass `codegen.printing.CodePrinter` and will make it easier to handle different import strategies (e.g. pyccel#1657) **Commit Summary** - Move `get_additional_imports` to `codegen.printing.CodePrinter` - Move `add_import` to `codegen.printing.CodePrinter` - Use a dictionary for imports in `codegen.printing.FCodePrinter` to match what is done in `codegen.printing.CCodePrinter` - Change the storage type of the `_additional_imports` dict in `codegen.printing.PythonCodePrinter` to match what is done in `codegen.printing.CCodePrinter`
1 parent f508fa3 commit 3b3c327

File tree

8 files changed

+55
-73
lines changed

8 files changed

+55
-73
lines changed

.github/workflows/anaconda_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ jobs:
108108
name: coverage-artifact
109109
path: .coverage
110110
retention-days: 1
111+
include-hidden-files: true
111112
- name: "Post completed"
112113
if: always() && github.event_name != 'push'
113114
run:

.github/workflows/intel.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ jobs:
102102
name: coverage-artifact
103103
path: .coverage
104104
retention-days: 1
105+
include-hidden-files: true
105106
- name: "Post completed"
106107
if: always() && github.event_name != 'push'
107108
run:

.github/workflows/linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ jobs:
105105
name: coverage-artifact-${{ matrix.python_version }}
106106
path: .coverage
107107
retention-days: 3
108+
include-hidden-files: true
108109
- name: "Post completed"
109110
if: always() && github.event_name != 'push'
110111
run:

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ All notable changes to this project will be documented in this file.
8787
- \[INTERNALS\] Remove unnecessary argument `kind` to `Errors.set_target`.
8888
- \[INTERNALS\] Handle STC imports with Pyccel objects.
8989
- \[INTERNALS\] Stop using ndarrays as an intermediate step to return arrays from Fortran code.
90+
- \[INTERNALS\] Unify the strategy for handling additional imports in the printing stage for different languages.
9091

9192
### Deprecated
9293

pyccel/codegen/printing/ccode.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -314,28 +314,6 @@ def __init__(self, filename, prefix_module = None):
314314
self._current_module = None
315315
self._in_header = False
316316

317-
def get_additional_imports(self):
318-
"""return the additional imports collected in printing stage"""
319-
return self._additional_imports.keys()
320-
321-
def add_import(self, import_obj):
322-
"""
323-
Add a new import to the current context.
324-
325-
Add a new import to the current context. This allows the import to be recognised
326-
at the compiling/linking stage. If the source of the import is not new then any
327-
new targets are added to the Import object.
328-
329-
Parameters
330-
----------
331-
import_obj : Import
332-
The AST node describing the import.
333-
"""
334-
if import_obj.source not in self._additional_imports:
335-
self._additional_imports[import_obj.source] = import_obj
336-
elif import_obj.target:
337-
self._additional_imports[import_obj.source].define_target(import_obj.target)
338-
339317
def _format_code(self, lines):
340318
return self.indent_code(lines)
341319

pyccel/codegen/printing/codeprinter.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class CodePrinter:
2929
language = None
3030
def __init__(self):
3131
self._scope = None
32+
self._additional_imports = {}
3233

3334
def doprint(self, expr):
3435
"""
@@ -54,6 +55,38 @@ def doprint(self, expr):
5455
# Format the output
5556
return ''.join(self._format_code(lines))
5657

58+
def get_additional_imports(self):
59+
"""
60+
Get any additional imports collected during the printing stage.
61+
62+
Get any additional imports collected during the printing stage.
63+
This is necessary to correctly compile the files.
64+
65+
Returns
66+
-------
67+
iterable[str]
68+
An iterable of the include strings.
69+
"""
70+
return self._additional_imports.keys()
71+
72+
def add_import(self, import_obj):
73+
"""
74+
Add a new import to the current context.
75+
76+
Add a new import to the current context. This allows the import to be recognised
77+
at the compiling/linking stage. If the source of the import is not new then any
78+
new targets are added to the Import object.
79+
80+
Parameters
81+
----------
82+
import_obj : Import
83+
The AST node describing the import.
84+
"""
85+
if import_obj.source not in self._additional_imports:
86+
self._additional_imports[import_obj.source] = import_obj
87+
elif import_obj.target:
88+
self._additional_imports[import_obj.source].define_target(import_obj.target)
89+
5790
@property
5891
def scope(self):
5992
""" Return the scope associated with the object being printed

pyccel/codegen/printing/fcode.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,6 @@ def __init__(self, filename, prefix_module = None):
251251
self._current_class = None
252252

253253
self._additional_code = None
254-
self._additional_imports = set()
255254

256255
self.prefix_module = prefix_module
257256

@@ -268,10 +267,6 @@ def print_constant_imports(self):
268267
macros.append(macro)
269268
return "\n".join(macros)
270269

271-
def get_additional_imports(self):
272-
"""return the additional modules collected for importing in printing stage"""
273-
return [i.source for i in self._additional_imports]
274-
275270
def set_current_class(self, name):
276271

277272
self._current_class = name
@@ -438,13 +433,14 @@ def _handle_inline_func_call(self, expr, assign_lhs = None):
438433
func.reinstate_presence_checks()
439434
func.swap_out_args()
440435

441-
self._additional_imports.update(func.imports)
436+
for i in func.imports:
437+
self.add_import(i)
442438
if func.global_vars or func.global_funcs:
443439
mod = func.get_direct_user_nodes(lambda x: isinstance(x, Module))[0]
444440
current_mod = expr.get_user_nodes(Module, excluded_nodes=(FunctionCall,))[0]
445441
if current_mod is not mod:
446-
self._additional_imports.add(Import(mod.name, [AsName(v, v.name) \
447-
for v in (*func.global_vars, *func.global_funcs)]))
442+
self.add_import(Import(mod.name, [AsName(v, v.name) \
443+
for v in (*func.global_vars, *func.global_funcs)]))
448444
for v in (*func.global_vars, *func.global_funcs):
449445
self.scope.insert_symbol(v.name)
450446

@@ -560,7 +556,7 @@ def _print_Module(self, expr):
560556
# ...
561557

562558
contains = 'contains\n' if (expr.funcs or expr.classes or expr.interfaces) else ''
563-
imports += ''.join(self._print(i) for i in self._additional_imports)
559+
imports += ''.join(self._print(i) for i in self._additional_imports.values())
564560
imports += "\n" + self.print_constant_imports()
565561
parts = ['module {}\n'.format(name),
566562
imports,
@@ -604,7 +600,7 @@ def _print_Program(self, expr):
604600

605601
decs += '\ninteger :: ierr = -1' +\
606602
'\ninteger, allocatable :: status (:)'
607-
imports += ''.join(self._print(i) for i in self._additional_imports)
603+
imports += ''.join(self._print(i) for i in self._additional_imports.values())
608604
imports += "\n" + self.print_constant_imports()
609605
parts = ['program {}\n'.format(name),
610606
imports,
@@ -1379,7 +1375,7 @@ def _print_NumpyAmax(self, expr):
13791375
arg_code = self._print(array_arg)
13801376

13811377
if isinstance(array_arg.dtype.primitive_type, PrimitiveComplexType):
1382-
self._additional_imports.add(Import('pyc_math_f90', Module('pyc_math_f90',(),())))
1378+
self.add_import(Import('pyc_math_f90', Module('pyc_math_f90',(),())))
13831379
return f'amax({array_arg})'
13841380
else:
13851381
return f'maxval({arg_code})'
@@ -1392,7 +1388,7 @@ def _print_NumpyAmin(self, expr):
13921388
arg_code = self._print(array_arg)
13931389

13941390
if isinstance(array_arg.dtype.primitive_type, PrimitiveComplexType):
1395-
self._additional_imports.add(Import('pyc_math_f90', Module('pyc_math_f90',(),())))
1391+
self.add_import(Import('pyc_math_f90', Module('pyc_math_f90',(),())))
13961392
return f'amin({array_arg})'
13971393
else:
13981394
return f'minval({arg_code})'
@@ -2774,7 +2770,7 @@ def _print_NumpySign(self, expr):
27742770
arg_code = self._print(arg)
27752771
if isinstance(expr.dtype.primitive_type, PrimitiveComplexType):
27762772
func = PyccelFunctionDef('numpy_sign', NumpySign)
2777-
self._additional_imports.add(Import('numpy_f90', AsName(func, 'numpy_sign')))
2773+
self.add_import(Import('numpy_f90', AsName(func, 'numpy_sign')))
27782774
return f'numpy_sign({arg_code})'
27792775
else:
27802776
cast_func = DtypePrecisionToCastFunction[expr.dtype]
@@ -2822,7 +2818,7 @@ def _print_MathFunctionBase(self, expr):
28222818
except KeyError:
28232819
errors.report(PYCCEL_RESTRICTION_TODO, severity='fatal')
28242820
if func_name.startswith("pyc"):
2825-
self._additional_imports.add(Import('pyc_math_f90', Module('pyc_math_f90',(),())))
2821+
self.add_import(Import('pyc_math_f90', Module('pyc_math_f90',(),())))
28262822
args = []
28272823
for arg in expr.args:
28282824
if arg.dtype != expr.dtype:

pyccel/codegen/printing/pycode.py

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ class PythonCodePrinter(CodePrinter):
7777
def __init__(self, filename):
7878
errors.set_target(filename)
7979
super().__init__()
80-
self._additional_imports = {}
8180
self._aliases = {}
8281
self._ignore_funcs = []
8382

@@ -92,24 +91,6 @@ def _indent_codestring(self, lines):
9291
def _format_code(self, lines):
9392
return lines
9493

95-
def get_additional_imports(self):
96-
"""return the additional imports collected in printing stage"""
97-
imports = [i for tup in self._additional_imports.values() for i in tup[1]]
98-
return imports
99-
100-
def insert_new_import(self, source, target, alias = None):
101-
""" Add an import of an object which may have been
102-
added by pyccel and therefore may not have been imported
103-
"""
104-
if alias and alias!=target:
105-
target = AsName(target, alias)
106-
import_obj = Import(source, target)
107-
source = str(source)
108-
src_info = self._additional_imports.setdefault(source, (set(), []))
109-
if any(i not in src_info[0] for i in import_obj.target):
110-
src_info[0].update(import_obj.target)
111-
src_info[1].append(import_obj)
112-
11394
def _find_functional_expr_and_iterables(self, expr):
11495
"""
11596
Traverse through the loop representing a FunctionalFor or GeneratorComprehension
@@ -180,9 +161,7 @@ def _get_numpy_name(self, expr):
180161
type_name = expr.name
181162
name = self._aliases.get(cls, type_name)
182163
if name == type_name and cls not in (PythonBool, PythonInt, PythonFloat, PythonComplex):
183-
self.insert_new_import(
184-
source = 'numpy',
185-
target = AsName(cls, name))
164+
self.add_import(Import('numpy', [AsName(cls, name)]))
186165
return name
187166

188167
#----------------------------------------------------------------------
@@ -334,7 +313,7 @@ def _print_FunctionDef(self, expr):
334313
decorators.pop('template')
335314
for n,f in decorators.items():
336315
if n in pyccel_decorators:
337-
self.insert_new_import(DottedName('pyccel.decorators'), AsName(decorators_mod[n], n))
316+
self.add_import(Import(DottedName('pyccel.decorators'), [AsName(decorators_mod[n], n)]))
338317
# TODO - All decorators must be stored in a list
339318
if not isinstance(f, list):
340319
f = [f]
@@ -394,7 +373,7 @@ def _print_Program(self, expr):
394373
module = modules[0]
395374
imports = ''.join(self._print(i) for i in expr.imports if i.source_module is not module)
396375
body = self._print(expr.body)
397-
imports += ''.join(self._print(i) for i in self.get_additional_imports())
376+
imports += ''.join(self._print(i) for i in self._additional_imports.values())
398377

399378
body = imports+body
400379
body = self._indent_codestring(body)
@@ -635,9 +614,7 @@ def _print_FunctionalFor(self, expr):
635614

636615
name = self._aliases.get(type(expr),'array')
637616
if name == 'array':
638-
self.insert_new_import(
639-
source = 'numpy',
640-
target = AsName(NumpyArray, 'array'))
617+
self.add_import(Import('numpy', [AsName(NumpyArray, 'array')]))
641618

642619
return '{} = {}([{} {}])\n'.format(lhs, name, body, for_loops)
643620

@@ -814,18 +791,14 @@ def _print_NumpyNorm(self, expr):
814791
def _print_NumpyNonZero(self, expr):
815792
name = self._aliases.get(type(expr),'nonzero')
816793
if name == 'nonzero':
817-
self.insert_new_import(
818-
source = 'numpy',
819-
target = AsName(NumpyNonZero, 'nonzero'))
794+
self.add_import(Import('numpy', [AsName(NumpyNonZero, 'nonzero')]))
820795
arg = self._print(expr.array)
821796
return "{}({})".format(name, arg)
822797

823798
def _print_NumpyCountNonZero(self, expr):
824799
name = self._aliases.get(type(expr),'count_nonzero')
825800
if name == 'count_nonzero':
826-
self.insert_new_import(
827-
source = 'numpy',
828-
target = AsName(NumpyNonZero, 'count_nonzero'))
801+
self.add_import(Import('numpy', [AsName(NumpyNonZero, 'count_nonzero')]))
829802

830803
axis_arg = expr.axis
831804

@@ -933,9 +906,7 @@ def _print_Literal(self, expr):
933906
cast_name = cast_func.name
934907
name = self._aliases.get(cast_func, cast_name)
935908
if is_numpy and name == cast_name:
936-
self.insert_new_import(
937-
source = 'numpy',
938-
target = AsName(cast_func, cast_name))
909+
self.add_import(Import('numpy', [AsName(cast_func, cast_name)]))
939910
return '{}({})'.format(name, repr(expr.python_value))
940911
else:
941912
return repr(expr.python_value)
@@ -983,7 +954,7 @@ def _print_Module(self, expr):
983954
if free_func:
984955
self._ignore_funcs.append(free_func)
985956

986-
imports += ''.join(self._print(i) for i in self.get_additional_imports())
957+
imports += ''.join(self._print(i) for i in self._additional_imports.values())
987958

988959
body = ''.join((interfaces, funcs, classes, init_body))
989960

0 commit comments

Comments
 (0)