@@ -77,7 +77,6 @@ class PythonCodePrinter(CodePrinter):
77
77
def __init__ (self , filename ):
78
78
errors .set_target (filename )
79
79
super ().__init__ ()
80
- self ._additional_imports = {}
81
80
self ._aliases = {}
82
81
self ._ignore_funcs = []
83
82
@@ -92,24 +91,6 @@ def _indent_codestring(self, lines):
92
91
def _format_code (self , lines ):
93
92
return lines
94
93
95
- def get_additional_imports (self ):
96
- """return the additional imports collected in printing stage"""
97
- imports = [i for tup in self ._additional_imports .values () for i in tup [1 ]]
98
- return imports
99
-
100
- def insert_new_import (self , source , target , alias = None ):
101
- """ Add an import of an object which may have been
102
- added by pyccel and therefore may not have been imported
103
- """
104
- if alias and alias != target :
105
- target = AsName (target , alias )
106
- import_obj = Import (source , target )
107
- source = str (source )
108
- src_info = self ._additional_imports .setdefault (source , (set (), []))
109
- if any (i not in src_info [0 ] for i in import_obj .target ):
110
- src_info [0 ].update (import_obj .target )
111
- src_info [1 ].append (import_obj )
112
-
113
94
def _find_functional_expr_and_iterables (self , expr ):
114
95
"""
115
96
Traverse through the loop representing a FunctionalFor or GeneratorComprehension
@@ -180,9 +161,7 @@ def _get_numpy_name(self, expr):
180
161
type_name = expr .name
181
162
name = self ._aliases .get (cls , type_name )
182
163
if name == type_name and cls not in (PythonBool , PythonInt , PythonFloat , PythonComplex ):
183
- self .insert_new_import (
184
- source = 'numpy' ,
185
- target = AsName (cls , name ))
164
+ self .add_import (Import ('numpy' , [AsName (cls , name )]))
186
165
return name
187
166
188
167
#----------------------------------------------------------------------
@@ -334,7 +313,7 @@ def _print_FunctionDef(self, expr):
334
313
decorators .pop ('template' )
335
314
for n ,f in decorators .items ():
336
315
if n in pyccel_decorators :
337
- self .insert_new_import ( DottedName ('pyccel.decorators' ), AsName (decorators_mod [n ], n ))
316
+ self .add_import ( Import ( DottedName ('pyccel.decorators' ), [ AsName (decorators_mod [n ], n )] ))
338
317
# TODO - All decorators must be stored in a list
339
318
if not isinstance (f , list ):
340
319
f = [f ]
@@ -394,7 +373,7 @@ def _print_Program(self, expr):
394
373
module = modules [0 ]
395
374
imports = '' .join (self ._print (i ) for i in expr .imports if i .source_module is not module )
396
375
body = self ._print (expr .body )
397
- imports += '' .join (self ._print (i ) for i in self .get_additional_imports ())
376
+ imports += '' .join (self ._print (i ) for i in self ._additional_imports . values ())
398
377
399
378
body = imports + body
400
379
body = self ._indent_codestring (body )
@@ -635,9 +614,7 @@ def _print_FunctionalFor(self, expr):
635
614
636
615
name = self ._aliases .get (type (expr ),'array' )
637
616
if name == 'array' :
638
- self .insert_new_import (
639
- source = 'numpy' ,
640
- target = AsName (NumpyArray , 'array' ))
617
+ self .add_import (Import ('numpy' , [AsName (NumpyArray , 'array' )]))
641
618
642
619
return '{} = {}([{} {}])\n ' .format (lhs , name , body , for_loops )
643
620
@@ -814,18 +791,14 @@ def _print_NumpyNorm(self, expr):
814
791
def _print_NumpyNonZero (self , expr ):
815
792
name = self ._aliases .get (type (expr ),'nonzero' )
816
793
if name == 'nonzero' :
817
- self .insert_new_import (
818
- source = 'numpy' ,
819
- target = AsName (NumpyNonZero , 'nonzero' ))
794
+ self .add_import (Import ('numpy' , [AsName (NumpyNonZero , 'nonzero' )]))
820
795
arg = self ._print (expr .array )
821
796
return "{}({})" .format (name , arg )
822
797
823
798
def _print_NumpyCountNonZero (self , expr ):
824
799
name = self ._aliases .get (type (expr ),'count_nonzero' )
825
800
if name == 'count_nonzero' :
826
- self .insert_new_import (
827
- source = 'numpy' ,
828
- target = AsName (NumpyNonZero , 'count_nonzero' ))
801
+ self .add_import (Import ('numpy' , [AsName (NumpyNonZero , 'count_nonzero' )]))
829
802
830
803
axis_arg = expr .axis
831
804
@@ -933,9 +906,7 @@ def _print_Literal(self, expr):
933
906
cast_name = cast_func .name
934
907
name = self ._aliases .get (cast_func , cast_name )
935
908
if is_numpy and name == cast_name :
936
- self .insert_new_import (
937
- source = 'numpy' ,
938
- target = AsName (cast_func , cast_name ))
909
+ self .add_import (Import ('numpy' , [AsName (cast_func , cast_name )]))
939
910
return '{}({})' .format (name , repr (expr .python_value ))
940
911
else :
941
912
return repr (expr .python_value )
@@ -983,7 +954,7 @@ def _print_Module(self, expr):
983
954
if free_func :
984
955
self ._ignore_funcs .append (free_func )
985
956
986
- imports += '' .join (self ._print (i ) for i in self .get_additional_imports ())
957
+ imports += '' .join (self ._print (i ) for i in self ._additional_imports . values ())
987
958
988
959
body = '' .join ((interfaces , funcs , classes , init_body ))
989
960
0 commit comments