69
69
70
70
from pyccel .ast .numpytypes import NumpyNDArrayType
71
71
72
- from pyccel .ast .operators import PyccelAdd , PyccelMul , PyccelMinus , PyccelAnd
72
+ from pyccel .ast .operators import PyccelAdd , PyccelMul , PyccelMinus , PyccelAnd , PyccelEq
73
73
from pyccel .ast .operators import PyccelMod , PyccelNot , PyccelAssociativeParenthesis
74
74
from pyccel .ast .operators import PyccelUnarySub , PyccelLt , PyccelGt , IfTernaryOperator
75
75
@@ -524,12 +524,12 @@ def _calculate_class_names(self, expr):
524
524
scope .rename_function (f , suggested_name )
525
525
f .cls_name = scope .get_new_name (f'{ name } _{ f .name } ' )
526
526
527
- def _get_comparison_operator (self , element_type , imports_and_macros ):
527
+ def _define_gFTL_element (self , element_type , imports_and_macros , element_name ):
528
528
"""
529
- Get an AST node describing a comparison operator between two objects of the same type.
529
+ Get lists of nodes describing comparison operators between two objects of the same type.
530
530
531
- Get an AST node describing a comparison operator between two objects of the same type.
532
- This is useful when defining gFTL modules.
531
+ Get lists of nodes describing a comparison operators between two objects of the same type.
532
+ This is necessary when defining gFTL modules.
533
533
534
534
Parameters
535
535
----------
@@ -539,25 +539,48 @@ def _get_comparison_operator(self, element_type, imports_and_macros):
539
539
imports_and_macros : list
540
540
A list of imports or macros for the gFTL module.
541
541
542
+ element_name : str
543
+ The name of the element whose properties are being specified.
544
+
542
545
Returns
543
546
-------
544
- PyccelAstNode
545
- An AST node describing a comparison operator.
547
+ defs : list[MacroDefinition]
548
+ A list of nodes describing the macros defining comparison operator.
549
+ undefs : list[MacroUndef]
550
+ A list of nodes which undefine the macros.
546
551
"""
547
- tmpVar_x = Variable (element_type , 'x' )
548
- tmpVar_y = Variable (element_type , 'y' )
549
- if isinstance (element_type .primitive_type , PrimitiveComplexType ):
550
- complex_tool_import = Import ('pyc_tools_f90' , Module ('pyc_tools_f90' ,(),()))
551
- self .add_import (complex_tool_import )
552
- imports_and_macros .append (complex_tool_import )
553
- compare_func = FunctionDef ('complex_comparison' ,
554
- [FunctionDefArgument (tmpVar_x ), FunctionDefArgument (tmpVar_y )],
555
- [FunctionDefResult (Variable (PythonNativeBool (), 'c' ))], [])
556
- lt_def = compare_func (tmpVar_x , tmpVar_y )
552
+ if isinstance (element_type , FixedSizeNumericType ):
553
+ tmpVar_x = Variable (element_type , 'x' )
554
+ tmpVar_y = Variable (element_type , 'y' )
555
+ if isinstance (element_type .primitive_type , PrimitiveComplexType ):
556
+ complex_tool_import = Import ('pyc_tools_f90' , Module ('pyc_tools_f90' ,(),()))
557
+ self .add_import (complex_tool_import )
558
+ imports_and_macros .append (complex_tool_import )
559
+ compare_func = FunctionDef ('complex_comparison' ,
560
+ [FunctionDefArgument (tmpVar_x ), FunctionDefArgument (tmpVar_y )],
561
+ [FunctionDefResult (Variable (PythonNativeBool (), 'c' ))], [])
562
+ lt_def = compare_func (tmpVar_x , tmpVar_y )
563
+ else :
564
+ lt_def = PyccelAssociativeParenthesis (PyccelLt (tmpVar_x , tmpVar_y ))
565
+
566
+ defs = [MacroDefinition (element_name , element_type .primitive_type ),
567
+ MacroDefinition (f'{ element_name } _KINDLEN(context)' , KindSpecification (element_type )),
568
+ MacroDefinition (f'{ element_name } _LT(x,y)' , lt_def ),
569
+ MacroDefinition (f'{ element_name } _EQ(x,y)' , PyccelAssociativeParenthesis (PyccelEq (tmpVar_x , tmpVar_y )))]
570
+ undefs = [MacroUndef (element_name ),
571
+ MacroUndef (f'{ element_name } _KINDLEN' ),
572
+ MacroUndef (f'{ element_name } _LT' ),
573
+ MacroUndef (f'{ element_name } _EQ' )]
557
574
else :
558
- lt_def = PyccelAssociativeParenthesis (PyccelLt (tmpVar_x , tmpVar_y ))
575
+ defs = [MacroDefinition (element_name , element_type )]
576
+ undefs = [MacroUndef (element_name )]
559
577
560
- return lt_def
578
+ if isinstance (element_type , (NumpyNDArrayType , HomogeneousTupleType )):
579
+ defs .append (MacroDefinition (f'{ element_name } _rank' , element_type .rank ))
580
+ undefs .append (MacroUndef (f'{ element_name } _rank' ))
581
+ elif not isinstance (element_type , FixedSizeNumericType ):
582
+ raise NotImplementedError ("Support for containers of types defined in other modules is not yet implemented" )
583
+ return defs , undefs
561
584
562
585
def _build_gFTL_module (self , expr_type ):
563
586
"""
@@ -583,60 +606,40 @@ def _build_gFTL_module(self, expr_type):
583
606
module = self ._generated_gFTL_types [matching_expr_type ]
584
607
mod_name = module .name
585
608
else :
586
- if isinstance (expr_type , HomogeneousListType ):
609
+ if isinstance (expr_type , ( HomogeneousListType , HomogeneousSetType ) ):
587
610
element_type = expr_type .element_type
588
- if isinstance (element_type , FixedSizeNumericType ):
589
- imports_and_macros = [MacroDefinition ('T' , element_type .primitive_type ),
590
- MacroDefinition ('T_KINDLEN(context)' , KindSpecification (element_type ))]
611
+ if isinstance (expr_type , HomogeneousSetType ):
612
+ type_name = 'Set'
613
+ if not isinstance (element_type , FixedSizeNumericType ):
614
+ raise NotImplementedError ("Support for sets of types which define their own < operator is not yet implemented" )
591
615
else :
592
- imports_and_macros = [MacroDefinition ('T' , element_type )]
593
- if isinstance (element_type , (NumpyNDArrayType , HomogeneousTupleType )):
594
- imports_and_macros .append (MacroDefinition ('T_rank' , element_type .rank ))
595
- elif not isinstance (element_type , FixedSizeNumericType ):
596
- raise NotImplementedError ("Support for lists of types defined in other modules is not yet implemented" )
597
- imports_and_macros .extend ([MacroDefinition ('Vector' , expr_type ),
598
- MacroDefinition ('VectorIterator' , IteratorType (expr_type )),
599
- Import (LiteralString ('vector/template.inc' ), Module ('_' , (), ())),
600
- MacroUndef ('Vector' ),
601
- MacroUndef ('VectorIterator' ),])
602
- elif isinstance (expr_type , HomogeneousSetType ):
603
- element_type = expr_type .element_type
616
+ type_name = 'Vector'
604
617
imports_and_macros = []
605
- if isinstance (element_type , FixedSizeNumericType ):
606
- lt_def = self ._get_comparison_operator (element_type , imports_and_macros )
607
- imports_and_macros .extend ([MacroDefinition ('T' , element_type .primitive_type ),
608
- MacroDefinition ('T_KINDLEN(context)' , KindSpecification (element_type )),
609
- MacroDefinition ('T_LT(x,y)' , lt_def )])
610
- else :
611
- raise NotImplementedError ("Support for sets of types which define their own < operator is not yet implemented" )
612
- imports_and_macros .extend ([MacroDefinition ('Set' , expr_type ),
613
- MacroDefinition ('SetIterator' , IteratorType (expr_type )),
614
- Import (LiteralString ('set/template.inc' ), Module ('_' , (), ())),
615
- MacroUndef ('Set' ),
616
- MacroUndef ('SetIterator' )])
618
+ defs , undefs = self ._define_gFTL_element (element_type , imports_and_macros , 'T' )
619
+ imports_and_macros .extend ([* defs ,
620
+ MacroDefinition (type_name , expr_type ),
621
+ MacroDefinition (f'{ type_name } Iterator' , IteratorType (expr_type )),
622
+ Import (LiteralString (f'{ type_name .lower ()} /template.inc' ), Module ('_' , (), ())),
623
+ MacroUndef (type_name ),
624
+ MacroUndef (f'{ type_name } Iterator' ),
625
+ * undefs ])
617
626
elif isinstance (expr_type , DictType ):
618
627
key_type = expr_type .key_type
619
628
value_type = expr_type .value_type
620
629
imports_and_macros = []
621
- if isinstance (key_type , FixedSizeNumericType ):
622
- lt_def = self ._get_comparison_operator (key_type , imports_and_macros )
623
- imports_and_macros .extend ([MacroDefinition ('Key' , key_type .primitive_type ),
624
- MacroDefinition ('Key_KINDLEN(context)' , KindSpecification (key_type )),
625
- MacroDefinition ('Key_LT(x,y)' , lt_def )])
626
- else :
630
+ if not isinstance (key_type , FixedSizeNumericType ):
627
631
raise NotImplementedError ("Support for dicts whose keys define their own < operator is not yet implemented" )
628
- if isinstance (value_type , FixedSizeNumericType ):
629
- imports_and_macros .extend ([MacroDefinition ('T' , value_type .primitive_type ),
630
- MacroDefinition ('T_KINDLEN(context)' , KindSpecification (value_type ))])
631
- else :
632
- raise NotImplementedError (f"Support for dictionary values of type { value_type } not yet implemented" )
633
- imports_and_macros .extend ([MacroDefinition ('Pair' , PairType (key_type , value_type )),
632
+ key_defs , key_undefs = self ._define_gFTL_element (key_type , imports_and_macros , 'Key' )
633
+ val_defs , val_undefs = self ._define_gFTL_element (value_type , imports_and_macros , 'T' )
634
+ imports_and_macros .extend ([* key_defs , * val_defs ,
635
+ MacroDefinition ('Pair' , PairType (key_type , value_type )),
634
636
MacroDefinition ('Map' , expr_type ),
635
637
MacroDefinition ('MapIterator' , IteratorType (expr_type )),
636
638
Import (LiteralString ('map/template.inc' ), Module ('_' , (), ())),
637
639
MacroUndef ('Pair' ),
638
640
MacroUndef ('Map' ),
639
- MacroUndef ('MapIterator' )])
641
+ MacroUndef ('MapIterator' ),
642
+ * key_undefs , * val_undefs ])
640
643
else :
641
644
raise NotImplementedError (f"Unkown gFTL import for type { expr_type } " )
642
645
@@ -3077,6 +3080,19 @@ def _print_PyccelNot(self, expr):
3077
3080
return '{} == 0' .format (a )
3078
3081
return '.not. {}' .format (a )
3079
3082
3083
+ def _print_PyccelIn (self , expr ):
3084
+ container_type = expr .container .class_type
3085
+ element = self ._print (expr .element )
3086
+ container = self ._print (expr .container )
3087
+ if isinstance (container_type , (HomogeneousSetType , DictType )):
3088
+ return f'{ container } % count({ element } ) /= 0'
3089
+ elif isinstance (container_type , HomogeneousListType ):
3090
+ return f'{ container } % get_index({ element } ) /= 0'
3091
+ else :
3092
+ raise errors .report (PYCCEL_RESTRICTION_TODO ,
3093
+ symbol = expr ,
3094
+ severity = 'fatal' )
3095
+
3080
3096
def _print_Header (self , expr ):
3081
3097
return ''
3082
3098
0 commit comments