249
249
'stdbool' ,
250
250
'assert' ]}
251
251
252
- import_header_guard_prefix = {'Set_extensions ' : '_TOOLS_SET' ,
253
- 'List_extensions ' : '_TOOLS_LIST' ,
254
- 'Common_extensions ' : '_TOOLS_COMMON' }
252
+ import_header_guard_prefix = {'stc/hset ' : '_TOOLS_SET' ,
253
+ 'stc/vec ' : '_TOOLS_LIST' ,
254
+ 'stc/common ' : '_TOOLS_COMMON' }
255
255
256
-
257
- stc_header_mapping = {'List_extensions' : 'stc/vec' ,
258
- 'Set_extensions' : 'stc/hset' ,
259
- 'Common_extensions' : 'stc/common' }
256
+ stc_extension_mapping = {'stc/vec' : 'List_extensions' ,
257
+ 'stc/hset' : 'Set_extensions' ,
258
+ 'stc/common' : 'Common_extensions' }
260
259
261
260
class CCodePrinter (CodePrinter ):
262
261
"""
@@ -682,32 +681,6 @@ def init_stc_container(self, expr, assignment_var):
682
681
init = f'{ container_name } = c_init({ dtype } , { keyraw } );\n '
683
682
return init
684
683
685
- def invalidate_stc_headers (self , imports ):
686
- """
687
- Invalidate STC headers when STC extension headers are present.
688
-
689
- This function iterates over the list of imports and removes any targets
690
- from STC headers if the target is present in their corresponding
691
- STC extension headers.
692
- The STC extension headers take care of including the standard
693
- headers.
694
-
695
- Parameters
696
- ----------
697
- imports : list of Import
698
- The list of Import objects representing the header files to include.
699
-
700
- Returns
701
- -------
702
- None
703
- The function modifies the `imports` list in-place.
704
- """
705
- for imp in imports :
706
- if imp .source in stc_header_mapping :
707
- for imp2 in imports :
708
- if imp2 .source == stc_header_mapping [imp .source ]:
709
- imp2 .remove_target (imp .target )
710
-
711
684
def rename_imported_methods (self , expr ):
712
685
"""
713
686
Rename class methods from user-defined imports.
@@ -773,7 +746,9 @@ def _print_PythonMinMax(self, expr):
773
746
elif len (arg ) > 2 and isinstance (arg .dtype .primitive_type , (PrimitiveFloatingPointType , PrimitiveIntegerType )):
774
747
key = self .get_declare_type (arg [0 ])
775
748
self .add_import (Import ('stc/common' , AsName (VariableTypeAnnotation (arg .dtype ), key )))
776
- self .add_import (Import ('Common_extensions' , AsName (VariableTypeAnnotation (arg .dtype ), key )))
749
+ self .add_import (Import ('Common_extensions' ,
750
+ AsName (VariableTypeAnnotation (arg .dtype ), key ),
751
+ ignore_at_print = True ))
777
752
return f'{ key } _{ expr .name } ({ len (arg )} , { ", " .join (self ._print (a ) for a in arg )} )'
778
753
else :
779
754
return errors .report (f"{ expr .name } in C does not support arguments of type { arg .dtype } " , symbol = expr ,
@@ -875,7 +850,6 @@ def _print_ModuleHeader(self, expr):
875
850
876
851
# Print imports last to be sure that all additional_imports have been collected
877
852
imports = [* expr .module .imports , * self ._additional_imports .values ()]
878
- self .invalidate_stc_headers (imports )
879
853
imports = '' .join (self ._print (i ) for i in imports )
880
854
881
855
self ._in_header = False
@@ -1052,7 +1026,7 @@ def _print_Import(self, expr):
1052
1026
source = source .name [- 1 ].python_value
1053
1027
else :
1054
1028
source = self ._print (source )
1055
- if source == 'Common_extensions ' :
1029
+ if source == 'stc/common ' :
1056
1030
code = ''
1057
1031
for t in expr .target :
1058
1032
element_decl = f'#define i_key { t .local_alias } \n '
@@ -1061,11 +1035,11 @@ def _print_Import(self, expr):
1061
1035
code += '' .join ((f'#ifndef { header_guard } \n ' ,
1062
1036
f'#define { header_guard } \n ' ,
1063
1037
element_decl ,
1064
- f'#include <{ stc_header_mapping [source ]} .h>\n ' ,
1065
1038
f'#include <{ source } .h>\n ' ,
1039
+ f'#include <{ stc_extension_mapping [source ]} .h>\n ' ,
1066
1040
f'#endif // { header_guard } \n \n ' ))
1067
1041
return code
1068
- elif source .startswith ('stc/' ) or source in import_header_guard_prefix :
1042
+ elif source .startswith ('stc/' ):
1069
1043
code = ''
1070
1044
for t in expr .target :
1071
1045
class_type = t .object .class_type
@@ -1087,11 +1061,12 @@ def _print_Import(self, expr):
1087
1061
f'#define { header_guard } \n ' ,
1088
1062
f'#define i_type { container_type } \n ' ,
1089
1063
element_decl ,
1090
- '#define i_more\n ' if source in import_header_guard_prefix else '' ,
1091
- f'#include <{ stc_header_mapping [ source ] } .h>\n ' if source in import_header_guard_prefix else ' ' ,
1092
- f'#include <{ source } .h>\n ' ,
1064
+ '#define i_more\n ' if source in stc_extension_mapping else '' ,
1065
+ f'#include <{ source } .h>\n ' ,
1066
+ f'#include <{ stc_extension_mapping [ source ] } .h>\n ' if source in stc_extension_mapping else '' ,
1093
1067
f'#endif // { header_guard } \n \n ' ))
1094
1068
return code
1069
+
1095
1070
# Get with a default value is not used here as it is
1096
1071
# slower and on most occasions the import will not be in the
1097
1072
# dictionary
@@ -1316,6 +1291,9 @@ def get_c_type(self, dtype):
1316
1291
element_type = self .get_c_type (dtype .element_type ).replace (' ' , '_' )
1317
1292
i_type = f'{ container_type } _{ element_type } '
1318
1293
self .add_import (Import (f'stc/{ container_type } ' , AsName (VariableTypeAnnotation (dtype ), i_type )))
1294
+ self .add_import (Import (f'{ stc_extension_mapping ["stc/" + container_type ]} ' ,
1295
+ AsName (VariableTypeAnnotation (dtype ), i_type ),
1296
+ ignore_at_print = True ))
1319
1297
return i_type
1320
1298
elif isinstance (dtype , DictType ):
1321
1299
container_type = 'hmap'
@@ -2664,7 +2642,6 @@ def _print_Program(self, expr):
2664
2642
decs = '' .join (self ._print (Declare (v )) for v in variables )
2665
2643
2666
2644
imports = [* expr .imports , * self ._additional_imports .values ()]
2667
- self .invalidate_stc_headers (imports )
2668
2645
imports = '' .join (self ._print (i ) for i in imports )
2669
2646
2670
2647
self .exit_scope ()
@@ -2706,7 +2683,10 @@ def _print_ListPop(self, expr):
2706
2683
c_type = self .get_c_type (class_type )
2707
2684
list_obj = self ._print (ObjectAddress (expr .list_obj ))
2708
2685
if expr .index_element :
2709
- self .add_import (Import ('List_extensions' , AsName (VariableTypeAnnotation (class_type ), c_type )))
2686
+ self .add_import (Import ('stc/vec' , AsName (VariableTypeAnnotation (class_type ), c_type )))
2687
+ self .add_import (Import ('List_extensions' ,
2688
+ AsName (VariableTypeAnnotation (class_type ), c_type ),
2689
+ ignore_at_print = True ))
2710
2690
if is_literal_integer (expr .index_element ) and int (expr .index_element ) < 0 :
2711
2691
idx_code = self ._print (PyccelAdd (PythonLen (expr .list_obj ), expr .index_element , simplify = True ))
2712
2692
else :
@@ -2720,7 +2700,10 @@ def _print_ListPop(self, expr):
2720
2700
def _print_SetPop (self , expr ):
2721
2701
dtype = expr .set_variable .class_type
2722
2702
var_type = self .get_c_type (dtype )
2723
- self .add_import (Import ('Set_extensions' , AsName (VariableTypeAnnotation (dtype ), var_type )))
2703
+ self .add_import (Import ('stc/hset' , AsName (VariableTypeAnnotation (dtype ), var_type )))
2704
+ self .add_import (Import ('Set_extensions' ,
2705
+ AsName (VariableTypeAnnotation (dtype ), var_type ),
2706
+ ignore_at_print = True ))
2724
2707
set_var = self ._print (ObjectAddress (expr .set_variable ))
2725
2708
return f'{ var_type } _pop({ set_var } )'
2726
2709
@@ -2747,7 +2730,10 @@ def _print_SetUnion(self, expr):
2747
2730
severity = 'error' , symbol = expr )
2748
2731
class_type = expr .set_variable .class_type
2749
2732
var_type = self .get_c_type (class_type )
2750
- self .add_import (Import ('Set_extensions' , AsName (VariableTypeAnnotation (class_type ), var_type )))
2733
+ self .add_import (Import ('stc/hset' , AsName (VariableTypeAnnotation (class_type ), var_type )))
2734
+ self .add_import (Import ('Set_extensions' ,
2735
+ AsName (VariableTypeAnnotation (class_type ), var_type ),
2736
+ ignore_at_print = True ))
2751
2737
set_var = self ._print (ObjectAddress (expr .set_variable ))
2752
2738
args = ', ' .join ([str (len (expr .args )), * (self ._print (ObjectAddress (a )) for a in expr .args )])
2753
2739
return f'{ var_type } _union({ set_var } , { args } )'
0 commit comments