52
52
from pyccel .ast .literals import Nil
53
53
54
54
from pyccel .ast .low_level_tools import MacroDefinition , IteratorType , PairType
55
+ from pyccel .ast .low_level_tools import MacroUndef
55
56
56
57
from pyccel .ast .mathext import math_constants
57
58
@@ -265,6 +266,7 @@ def __init__(self, filename, prefix_module = None):
265
266
266
267
self .prefix_module = prefix_module
267
268
269
+ self ._generated_gFTL_types = {}
268
270
self ._generated_gFTL_extensions = {}
269
271
270
272
def print_constant_imports (self ):
@@ -522,6 +524,41 @@ def _calculate_class_names(self, expr):
522
524
scope .rename_function (f , suggested_name )
523
525
f .cls_name = scope .get_new_name (f'{ name } _{ f .name } ' )
524
526
527
+ def _get_comparison_operator (self , element_type , imports_and_macros ):
528
+ """
529
+ Get an AST node describing a comparison operator between two objects of the same type.
530
+
531
+ Get an AST node describing a comparison operator between two objects of the same type.
532
+ This is useful when defining gFTL modules.
533
+
534
+ Parameters
535
+ ----------
536
+ element_type : PyccelType
537
+ The data type to be compared.
538
+
539
+ imports_and_macros : list
540
+ A list of imports or macros for the gFTL module.
541
+
542
+ Returns
543
+ -------
544
+ PyccelAstNode
545
+ An AST node describing a comparison operator.
546
+ """
547
+ tmpVar_x = Variable (element_type , 'x' )
548
+ tmpVar_y = Variable (element_type , 'y' )
549
+ if isinstance (element_type .primitive_type , PrimitiveComplexType ):
550
+ complex_tool_import = Import ('pyc_tools_f90' , Module ('pyc_tools_f90' ,(),()))
551
+ self .add_import (complex_tool_import )
552
+ imports_and_macros .append (complex_tool_import )
553
+ compare_func = FunctionDef ('complex_comparison' ,
554
+ [FunctionDefArgument (tmpVar_x ), FunctionDefArgument (tmpVar_y )],
555
+ [FunctionDefResult (Variable (PythonNativeBool (), 'c' ))], [])
556
+ lt_def = FunctionCall (compare_func , [tmpVar_x , tmpVar_y ])
557
+ else :
558
+ lt_def = PyccelAssociativeParenthesis (PyccelLt (tmpVar_x , tmpVar_y ))
559
+
560
+ return lt_def
561
+
525
562
def _build_gFTL_module (self , expr_type ):
526
563
"""
527
564
Build the gFTL module to create container types.
@@ -541,13 +578,12 @@ def _build_gFTL_module(self, expr_type):
541
578
The import which allows the new type to be accessed.
542
579
"""
543
580
# Get the type used in the dict for compatible types (e.g. float vs float64)
544
- matching_expr_type = next ((t for t in self ._generated_gFTL_extensions if expr_type == t ), None )
581
+ matching_expr_type = next ((t for t in self ._generated_gFTL_types if expr_type == t ), None )
545
582
if matching_expr_type :
546
- module = self ._generated_gFTL_extensions [matching_expr_type ]
583
+ module = self ._generated_gFTL_types [matching_expr_type ]
547
584
mod_name = module .name
548
585
else :
549
586
if isinstance (expr_type , HomogeneousListType ):
550
- include = Import (LiteralString ('vector/template.inc' ), Module ('_' , (), ()))
551
587
element_type = expr_type .element_type
552
588
if isinstance (element_type , FixedSizeNumericType ):
553
589
imports_and_macros = [MacroDefinition ('T' , element_type .primitive_type ),
@@ -558,50 +594,32 @@ def _build_gFTL_module(self, expr_type):
558
594
imports_and_macros .append (MacroDefinition ('T_rank' , element_type .rank ))
559
595
elif not isinstance (element_type , FixedSizeNumericType ):
560
596
raise NotImplementedError ("Support for lists of types defined in other modules is not yet implemented" )
561
- imports_and_macros .append (MacroDefinition ('Vector' , expr_type ))
562
- imports_and_macros .append (MacroDefinition ('VectorIterator' , IteratorType (expr_type )))
597
+ imports_and_macros .extend ([MacroDefinition ('Vector' , expr_type ),
598
+ MacroDefinition ('VectorIterator' , IteratorType (expr_type )),
599
+ Import (LiteralString ('vector/template.inc' ), Module ('_' , (), ())),
600
+ MacroUndef ('Vector' ),
601
+ MacroUndef ('VectorIterator' ),])
563
602
elif isinstance (expr_type , HomogeneousSetType ):
564
- include = Import (LiteralString ('set/template.inc' ), Module ('_' , (), ()))
565
603
element_type = expr_type .element_type
566
604
imports_and_macros = []
567
605
if isinstance (element_type , FixedSizeNumericType ):
568
- tmpVar_x = Variable (element_type , 'x' )
569
- tmpVar_y = Variable (element_type , 'y' )
570
- if isinstance (element_type .primitive_type , PrimitiveComplexType ):
571
- complex_tool_import = Import ('pyc_tools_f90' , Module ('pyc_tools_f90' ,(),()))
572
- self .add_import (complex_tool_import )
573
- imports_and_macros .append (complex_tool_import )
574
- compare_func = FunctionDef ('complex_comparison' ,
575
- [FunctionDefArgument (tmpVar_x ), FunctionDefArgument (tmpVar_y )],
576
- [FunctionDefResult (Variable (PythonNativeBool (), 'c' ))], [])
577
- lt_def = FunctionCall (compare_func , [tmpVar_x , tmpVar_y ])
578
- else :
579
- lt_def = PyccelAssociativeParenthesis (PyccelLt (tmpVar_x , tmpVar_y ))
606
+ lt_def = self ._get_comparison_operator (element_type , imports_and_macros )
580
607
imports_and_macros .extend ([MacroDefinition ('T' , element_type .primitive_type ),
581
608
MacroDefinition ('T_KINDLEN(context)' , KindSpecification (element_type )),
582
609
MacroDefinition ('T_LT(x,y)' , lt_def )])
583
610
else :
584
611
raise NotImplementedError ("Support for sets of types which define their own < operator is not yet implemented" )
585
- imports_and_macros .append (MacroDefinition ('Set' , expr_type ))
586
- imports_and_macros .append (MacroDefinition ('SetIterator' , IteratorType (expr_type )))
612
+ imports_and_macros .extend ([MacroDefinition ('Set' , expr_type ),
613
+ MacroDefinition ('SetIterator' , IteratorType (expr_type )),
614
+ Import (LiteralString ('set/template.inc' ), Module ('_' , (), ())),
615
+ MacroUndef ('Set' ),
616
+ MacroUndef ('SetIterator' )])
587
617
elif isinstance (expr_type , DictType ):
588
- include = Import (LiteralString ('map/template.inc' ), Module ('_' , (), ()))
589
618
key_type = expr_type .key_type
590
619
value_type = expr_type .value_type
591
620
imports_and_macros = []
592
621
if isinstance (key_type , FixedSizeNumericType ):
593
- tmpVar_x = Variable (key_type , 'x' )
594
- tmpVar_y = Variable (key_type , 'y' )
595
- if isinstance (key_type .primitive_type , PrimitiveComplexType ):
596
- complex_tool_import = Import ('pyc_tools_f90' , Module ('pyc_tools_f90' ,(),()))
597
- self .add_import (complex_tool_import )
598
- imports_and_macros .append (complex_tool_import )
599
- compare_func = FunctionDef ('complex_comparison' ,
600
- [FunctionDefArgument (tmpVar_x ), FunctionDefArgument (tmpVar_y )],
601
- [FunctionDefResult (Variable (PythonNativeBool (), 'c' ))], [])
602
- lt_def = FunctionCall (compare_func , [tmpVar_x , tmpVar_y ])
603
- else :
604
- lt_def = PyccelAssociativeParenthesis (PyccelLt (tmpVar_x , tmpVar_y ))
622
+ lt_def = self ._get_comparison_operator (key_type , imports_and_macros )
605
623
imports_and_macros .extend ([MacroDefinition ('Key' , key_type .primitive_type ),
606
624
MacroDefinition ('Key_KINDLEN(context)' , KindSpecification (key_type )),
607
625
MacroDefinition ('Key_LT(x,y)' , lt_def )])
@@ -612,15 +630,68 @@ def _build_gFTL_module(self, expr_type):
612
630
MacroDefinition ('T_KINDLEN(context)' , KindSpecification (value_type ))])
613
631
else :
614
632
raise NotImplementedError (f"Support for dictionary values of type { value_type } not yet implemented" )
615
- imports_and_macros .append (MacroDefinition ('Pair' , PairType (key_type , value_type )))
616
- imports_and_macros .append (MacroDefinition ('Map' , expr_type ))
617
- imports_and_macros .append (MacroDefinition ('MapIterator' , IteratorType (expr_type )))
633
+ imports_and_macros .extend ([MacroDefinition ('Pair' , PairType (key_type , value_type )),
634
+ MacroDefinition ('Map' , expr_type ),
635
+ MacroDefinition ('MapIterator' , IteratorType (expr_type )),
636
+ Import (LiteralString ('map/template.inc' ), Module ('_' , (), ())),
637
+ MacroUndef ('Pair' ),
638
+ MacroUndef ('Map' ),
639
+ MacroUndef ('MapIterator' )])
618
640
else :
619
641
raise NotImplementedError (f"Unkown gFTL import for type { expr_type } " )
620
642
621
643
typename = self ._print (expr_type )
622
644
mod_name = f'{ typename } _mod'
623
- module = Module (mod_name , (), (), scope = Scope (), imports = [* imports_and_macros , include ],
645
+ module = Module (mod_name , (), (), scope = Scope (), imports = imports_and_macros ,
646
+ is_external = True )
647
+
648
+ self ._generated_gFTL_types [expr_type ] = module
649
+
650
+ return Import (f'gFTL_extensions/{ mod_name } ' , module )
651
+
652
+ def _build_gFTL_extension_module (self , expr_type ):
653
+ """
654
+ Build the gFTL module to create container extension functions.
655
+
656
+ Create a module which will import the gFTL include files
657
+ in order to create container types (e.g lists, sets, etc).
658
+ The name of the module is derived from the name of the type.
659
+
660
+ Parameters
661
+ ----------
662
+ expr_type : DataType
663
+ The data type for which extensions are required.
664
+
665
+ Returns
666
+ -------
667
+ Import
668
+ The import which allows the new type to be accessed.
669
+ """
670
+ # Get the type used in the dict for compatible types (e.g. float vs float64)
671
+ matching_expr_type = next ((t for t in self ._generated_gFTL_types if expr_type == t ), None )
672
+ matching_expr_extensions = next ((t for t in self ._generated_gFTL_extensions if expr_type == t ), None )
673
+ typename = self ._print (expr_type )
674
+ mod_name = f'{ typename } _extensions_mod'
675
+ if matching_expr_extensions :
676
+ module = self ._generated_gFTL_extensions [matching_expr_extensions ]
677
+ else :
678
+ if matching_expr_type is None :
679
+ matching_expr_type = self ._build_gFTL_module (expr_type )
680
+ self .add_import (matching_expr_type )
681
+
682
+ type_module = matching_expr_type .source_module
683
+
684
+ if isinstance (expr_type , HomogeneousSetType ):
685
+ set_filename = LiteralString ('set/template.inc' )
686
+ imports_and_macros = [Import (LiteralString ('Set_extensions.inc' ), Module ('_' , (), ())) \
687
+ if getattr (i , 'source' , None ) == set_filename else i \
688
+ for i in type_module .imports ]
689
+ imports_and_macros .insert (0 , matching_expr_type )
690
+ self .add_import (Import ('gFTL_functions/Set_extensions' , Module ('_' , (), ()), ignore_at_print = True ))
691
+ else :
692
+ raise NotImplementedError (f"Unkown gFTL import for type { expr_type } " )
693
+
694
+ module = Module (mod_name , (), (), scope = Scope (), imports = imports_and_macros ,
624
695
is_external = True )
625
696
626
697
self ._generated_gFTL_extensions [expr_type ] = module
@@ -1253,6 +1324,14 @@ def _print_SetClear(self, expr):
1253
1324
var = self ._print (expr .set_variable )
1254
1325
return f'call { var } % clear()\n '
1255
1326
1327
+ def _print_SetPop (self , expr ):
1328
+ var = expr .set_variable
1329
+ expr_type = var .class_type
1330
+ var_code = self ._print (expr .set_variable )
1331
+ type_name = self ._print (expr_type )
1332
+ self .add_import (self ._build_gFTL_extension_module (expr_type ))
1333
+ return f'{ type_name } _pop({ var_code } )\n '
1334
+
1256
1335
#========================== Numpy Elements ===============================#
1257
1336
1258
1337
def _print_NumpySum (self , expr ):
@@ -3596,5 +3675,9 @@ def _print_MacroDefinition(self, expr):
3596
3675
obj = self ._print (expr .object )
3597
3676
return f'#define { name } { obj } \n '
3598
3677
3678
+ def _print_MacroUndef (self , expr ):
3679
+ name = expr .macro_name
3680
+ return f'#undef { name } \n '
3681
+
3599
3682
def _print_KindSpecification (self , expr ):
3600
3683
return f'(kind = { self .print_kind (expr .type_specifier )} )'
0 commit comments