14
14
15
15
from pyccel .ast .bind_c import BindCPointer
16
16
17
- from pyccel .ast .builtins import PythonRange , PythonComplex
17
+ from pyccel .ast .builtins import PythonRange , PythonComplex , PythonMin
18
18
from pyccel .ast .builtins import PythonPrint , PythonType , VariableIterator
19
+
19
20
from pyccel .ast .builtins import PythonList , PythonTuple , PythonSet , PythonDict , PythonLen
20
21
21
22
from pyccel .ast .builtin_methods .dict_methods import DictItems
248
249
'stdbool' ,
249
250
'assert' ]}
250
251
251
- import_header_guard_prefix = {'Set_extensions' : '_TOOLS_SET' ,
252
- 'List_extensions' : '_TOOLS_LIST' }
252
+ import_header_guard_prefix = {'Set_extensions' : '_TOOLS_SET' ,
253
+ 'List_extensions' : '_TOOLS_LIST' ,
254
+ 'Common_extensions' : '_TOOLS_COMMON' }
255
+
256
+
257
+ stc_header_mapping = {'List_extensions' : 'stc/vec' ,
258
+ 'Set_extensions' : 'stc/hset' ,
259
+ 'Common_extensions' : 'stc/common' }
253
260
254
261
class CCodePrinter (CodePrinter ):
255
262
"""
@@ -675,6 +682,32 @@ def init_stc_container(self, expr, assignment_var):
675
682
init = f'{ container_name } = c_init({ dtype } , { keyraw } );\n '
676
683
return init
677
684
685
+ def invalidate_stc_headers (self , imports ):
686
+ """
687
+ Invalidate STC headers when STC extension headers are present.
688
+
689
+ This function iterates over the list of imports and removes any targets
690
+ from STC headers if the target is present in their corresponding
691
+ STC extension headers.
692
+ The STC extension headers take care of including the standard
693
+ headers.
694
+
695
+ Parameters
696
+ ----------
697
+ imports : list of Import
698
+ The list of Import objects representing the header files to include.
699
+
700
+ Returns
701
+ -------
702
+ None
703
+ The function modifies the `imports` list in-place.
704
+ """
705
+ for imp in imports :
706
+ if imp .source in stc_header_mapping :
707
+ for imp2 in imports :
708
+ if imp2 .source == stc_header_mapping [imp .source ]:
709
+ imp2 .remove_target (imp .target )
710
+
678
711
def rename_imported_methods (self , expr ):
679
712
"""
680
713
Rename class methods from user-defined imports.
@@ -711,12 +744,13 @@ def _print_PythonAbs(self, expr):
711
744
func = "labs"
712
745
return "{}({})" .format (func , self ._print (expr .arg ))
713
746
714
- def _print_PythonMin (self , expr ):
747
+ def _print_PythonMinMax (self , expr ):
715
748
arg = expr .args [0 ]
716
749
if arg .dtype .primitive_type is PrimitiveFloatingPointType () and len (arg ) == 2 :
717
750
self .add_import (c_imports ['math' ])
718
- return "fmin({}, {})" .format (self ._print (arg [0 ]),
719
- self ._print (arg [1 ]))
751
+ arg1 = self ._print (arg [0 ])
752
+ arg2 = self ._print (arg [1 ])
753
+ return f"f{ expr .name } ({ arg1 } , { arg2 } )"
720
754
elif arg .dtype .primitive_type is PrimitiveIntegerType () and len (arg ) == 2 :
721
755
if isinstance (arg [0 ], Variable ):
722
756
arg1 = self ._print (arg [0 ])
@@ -734,38 +768,22 @@ def _print_PythonMin(self, expr):
734
768
self ._additional_code += self ._print (assign2 )
735
769
arg2 = self ._print (arg2_temp )
736
770
737
- return f"({ arg1 } < { arg2 } ? { arg1 } : { arg2 } )"
771
+ op = '<' if isinstance (expr , PythonMin ) else '>'
772
+ return f"({ arg1 } { op } { arg2 } ? { arg1 } : { arg2 } )"
773
+ elif len (arg ) > 2 and isinstance (arg .dtype .primitive_type , (PrimitiveFloatingPointType , PrimitiveIntegerType )):
774
+ key = self .get_declare_type (arg [0 ])
775
+ self .add_import (Import ('stc/common' , AsName (VariableTypeAnnotation (arg .dtype ), key )))
776
+ self .add_import (Import ('Common_extensions' , AsName (VariableTypeAnnotation (arg .dtype ), key )))
777
+ return f'{ key } _{ expr .name } ({ len (arg )} , { ", " .join (self ._print (a ) for a in arg )} )'
738
778
else :
739
- return errors .report ("min in C is only supported for 2 scalar arguments " , symbol = expr ,
779
+ return errors .report (f" { expr . name } in C does not support arguments of type { arg . dtype } " , symbol = expr ,
740
780
severity = 'fatal' )
741
781
742
- def _print_PythonMax (self , expr ):
743
- arg = expr .args [0 ]
744
- if arg .dtype .primitive_type is PrimitiveFloatingPointType () and len (arg ) == 2 :
745
- self .add_import (c_imports ['math' ])
746
- return "fmax({}, {})" .format (self ._print (arg [0 ]),
747
- self ._print (arg [1 ]))
748
- elif arg .dtype .primitive_type is PrimitiveIntegerType () and len (arg ) == 2 :
749
- if isinstance (arg [0 ], Variable ):
750
- arg1 = self ._print (arg [0 ])
751
- else :
752
- arg1_temp = self .scope .get_temporary_variable (PythonNativeInt ())
753
- assign1 = Assign (arg1_temp , arg [0 ])
754
- self ._additional_code += self ._print (assign1 )
755
- arg1 = self ._print (arg1_temp )
756
-
757
- if isinstance (arg [1 ], Variable ):
758
- arg2 = self ._print (arg [1 ])
759
- else :
760
- arg2_temp = self .scope .get_temporary_variable (PythonNativeInt ())
761
- assign2 = Assign (arg2_temp , arg [1 ])
762
- self ._additional_code += self ._print (assign2 )
763
- arg2 = self ._print (arg2_temp )
782
+ def _print_PythonMin (self , expr ):
783
+ return self ._print_PythonMinMax (expr )
764
784
765
- return f"({ arg1 } > { arg2 } ? { arg1 } : { arg2 } )"
766
- else :
767
- return errors .report ("max in C is only supported for 2 scalar arguments" , symbol = expr ,
768
- severity = 'fatal' )
785
+ def _print_PythonMax (self , expr ):
786
+ return self ._print_PythonMinMax (expr )
769
787
770
788
def _print_SysExit (self , expr ):
771
789
code = ""
@@ -857,6 +875,7 @@ def _print_ModuleHeader(self, expr):
857
875
858
876
# Print imports last to be sure that all additional_imports have been collected
859
877
imports = [* expr .module .imports , * self ._additional_imports .values ()]
878
+ self .invalidate_stc_headers (imports )
860
879
imports = '' .join (self ._print (i ) for i in imports )
861
880
862
881
self ._in_header = False
@@ -1033,7 +1052,20 @@ def _print_Import(self, expr):
1033
1052
source = source .name [- 1 ].python_value
1034
1053
else :
1035
1054
source = self ._print (source )
1036
- if source .startswith ('stc/' ) or source in import_header_guard_prefix :
1055
+ if source == 'Common_extensions' :
1056
+ code = ''
1057
+ for t in expr .target :
1058
+ element_decl = f'#define i_key { t .local_alias } \n '
1059
+ header_guard_prefix = import_header_guard_prefix .get (source , '' )
1060
+ header_guard = f'{ header_guard_prefix } _{ t .local_alias .upper ()} '
1061
+ code += '' .join ((f'#ifndef { header_guard } \n ' ,
1062
+ f'#define { header_guard } \n ' ,
1063
+ element_decl ,
1064
+ f'#include <{ stc_header_mapping [source ]} .h>\n ' ,
1065
+ f'#include <{ source } .h>\n ' ,
1066
+ f'#endif // { header_guard } \n \n ' ))
1067
+ return code
1068
+ elif source .startswith ('stc/' ) or source in import_header_guard_prefix :
1037
1069
code = ''
1038
1070
for t in expr .target :
1039
1071
class_type = t .object .class_type
@@ -1055,6 +1087,8 @@ def _print_Import(self, expr):
1055
1087
f'#define { header_guard } \n ' ,
1056
1088
f'#define i_type { container_type } \n ' ,
1057
1089
element_decl ,
1090
+ '#define i_more\n ' if source in import_header_guard_prefix else '' ,
1091
+ f'#include <{ stc_header_mapping [source ]} .h>\n ' if source in import_header_guard_prefix else '' ,
1058
1092
f'#include <{ source } .h>\n ' ,
1059
1093
f'#endif // { header_guard } \n \n ' ))
1060
1094
return code
@@ -2630,6 +2664,7 @@ def _print_Program(self, expr):
2630
2664
decs = '' .join (self ._print (Declare (v )) for v in variables )
2631
2665
2632
2666
imports = [* expr .imports , * self ._additional_imports .values ()]
2667
+ self .invalidate_stc_headers (imports )
2633
2668
imports = '' .join (self ._print (i ) for i in imports )
2634
2669
2635
2670
self .exit_scope ()
0 commit comments