Skip to content

Commit 86f322a

Browse files
authored
Add support for in operator (pyccel#2017)
Add support for `in` operator for `list`, `set` and `dict` containers for C Fortran and Python, as well as for classes. **Commit Summary** - Add `PyccelIn` class to describe the `in` operator - Add C support for `in` operator for list/set/dict - Add `#define i_use_cmp` to list declaration in C to activate sorting and searching functions - Add `_define_gFTL_element` function in Fortran code printer to define types and operators for types inside container types - Reduce duplication in `_build_gFTL_module` - Add Fortran support for `in` operator for list/set/dict - Add Python support for `in` operator for list/set/dict - In semantic stage handle `PyccelIn` for built-in types and class `__contains__` magic method - Add tests for `in` operator for lists/sets/dicts - Add tests for `in` operator for classes - Use `python_only_language` to allow some of the dictionary tests to be activated for other languages
1 parent b378f03 commit 86f322a

File tree

12 files changed

+243
-101
lines changed

12 files changed

+243
-101
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ All notable changes to this project will be documented in this file.
2727
- #1657 : Add the appropriate Fortran language equivalent for declaring a Python `list` container using the gFTL library.
2828
- #1658 : Add the appropriate Fortran language equivalent for declaring a Python `set` container using the gFTL library.
2929
- #1944 : Add the appropriate Fortran language equivalent for declaring a Python `dict` container using the gFTL library.
30+
- #2009 : Add support for `in` operator for `list`, `set`, `dict` and class containers.
3031
- #1874 : Add C and Fortran support for the `len()` function for the `list` container.
3132
- #1875 : Add C and Fortran support for the `len()` function for the `set` container.
3233
- #1908 : Add C and Fortran support for the `len()` function for the `dict` container.

pyccel/ast/operators.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
'Relational',
5757
'PyccelIs',
5858
'PyccelIsNot',
59+
'PyccelIn',
5960
'IfTernaryOperator'
6061
)
6162

@@ -1285,6 +1286,46 @@ def eval(self):
12851286

12861287
#==============================================================================
12871288

1289+
class PyccelIn(PyccelBooleanOperator):
1290+
"""
1291+
Represents an `in` expression in the code.
1292+
1293+
Represents an `in` expression in the code.
1294+
1295+
Parameters
1296+
----------
1297+
element : TypedAstNode
1298+
The first argument passed to the operator.
1299+
1300+
container : TypedAstNode
1301+
The first argument passed to the operator.
1302+
"""
1303+
__slots__ = ()
1304+
_precedence = 7
1305+
1306+
def __init__(self, element, container):
1307+
super().__init__(element, container)
1308+
1309+
@property
1310+
def element(self):
1311+
"""
1312+
First operator argument.
1313+
1314+
First operator argument.
1315+
"""
1316+
return self._args[0]
1317+
1318+
@property
1319+
def container(self):
1320+
"""
1321+
Second operator argument.
1322+
1323+
Second operator argument.
1324+
"""
1325+
return self._args[1]
1326+
1327+
#==============================================================================
1328+
12881329
class IfTernaryOperator(PyccelOperator):
12891330
"""
12901331
Represent a ternary conditional operator in the code.

pyccel/codegen/printing/ccode.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,20 @@ def _print_PyccelNot(self, expr):
972972
a = self._print(expr.args[0])
973973
return '!{}'.format(a)
974974

975+
def _print_PyccelIn(self, expr):
976+
container_type = expr.container.class_type
977+
element = self._print(expr.element)
978+
container = self._print(ObjectAddress(expr.container))
979+
c_type = self.get_c_type(expr.container.class_type)
980+
if isinstance(container_type, (HomogeneousSetType, DictType)):
981+
return f'{c_type}_contains({container}, {element})'
982+
elif isinstance(container_type, HomogeneousListType):
983+
return f'{c_type}_find({container}, {element}).ref != {c_type}_end({container}).ref'
984+
else:
985+
raise errors.report(PYCCEL_RESTRICTION_TODO,
986+
symbol = expr,
987+
severity='fatal')
988+
975989
def _print_PyccelMod(self, expr):
976990
self.add_import(c_imports['math'])
977991
self.add_import(c_imports['pyc_math_c'])
@@ -1018,16 +1032,19 @@ def _print_Import(self, expr):
10181032
if source.startswith('stc/') or source in import_header_guard_prefix:
10191033
code = ''
10201034
for t in expr.target:
1021-
dtype = t.object.class_type
1035+
class_type = t.object.class_type
10221036
container_type = t.local_alias
1023-
if isinstance(dtype, DictType):
1024-
container_key_key = self.get_c_type(dtype.key_type)
1025-
container_val_key = self.get_c_type(dtype.value_type)
1037+
if isinstance(class_type, DictType):
1038+
container_key_key = self.get_c_type(class_type.key_type)
1039+
container_val_key = self.get_c_type(class_type.value_type)
10261040
container_key = f'{container_key_key}_{container_val_key}'
10271041
element_decl = f'#define i_key {container_key_key}\n#define i_val {container_val_key}\n'
10281042
else:
1029-
container_key = self.get_c_type(dtype.element_type)
1043+
container_key = self.get_c_type(class_type.element_type)
10301044
element_decl = f'#define i_key {container_key}\n'
1045+
if isinstance(class_type, HomogeneousListType) and isinstance(class_type.element_type, FixedSizeNumericType) \
1046+
and not isinstance(class_type.element_type.primitive_type, PrimitiveComplexType):
1047+
element_decl += '#define i_use_cmp\n'
10311048
header_guard_prefix = import_header_guard_prefix.get(source, '')
10321049
header_guard = f'{header_guard_prefix}_{container_type.upper()}'
10331050
code += ''.join((f'#ifndef {header_guard}\n',

pyccel/codegen/printing/fcode.py

Lines changed: 76 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070
from pyccel.ast.numpytypes import NumpyNDArrayType
7171

72-
from pyccel.ast.operators import PyccelAdd, PyccelMul, PyccelMinus, PyccelAnd
72+
from pyccel.ast.operators import PyccelAdd, PyccelMul, PyccelMinus, PyccelAnd, PyccelEq
7373
from pyccel.ast.operators import PyccelMod, PyccelNot, PyccelAssociativeParenthesis
7474
from pyccel.ast.operators import PyccelUnarySub, PyccelLt, PyccelGt, IfTernaryOperator
7575

@@ -524,12 +524,12 @@ def _calculate_class_names(self, expr):
524524
scope.rename_function(f, suggested_name)
525525
f.cls_name = scope.get_new_name(f'{name}_{f.name}')
526526

527-
def _get_comparison_operator(self, element_type, imports_and_macros):
527+
def _define_gFTL_element(self, element_type, imports_and_macros, element_name):
528528
"""
529-
Get an AST node describing a comparison operator between two objects of the same type.
529+
Get lists of nodes describing comparison operators between two objects of the same type.
530530
531-
Get an AST node describing a comparison operator between two objects of the same type.
532-
This is useful when defining gFTL modules.
531+
Get lists of nodes describing a comparison operators between two objects of the same type.
532+
This is necessary when defining gFTL modules.
533533
534534
Parameters
535535
----------
@@ -539,25 +539,48 @@ def _get_comparison_operator(self, element_type, imports_and_macros):
539539
imports_and_macros : list
540540
A list of imports or macros for the gFTL module.
541541
542+
element_name : str
543+
The name of the element whose properties are being specified.
544+
542545
Returns
543546
-------
544-
PyccelAstNode
545-
An AST node describing a comparison operator.
547+
defs : list[MacroDefinition]
548+
A list of nodes describing the macros defining comparison operator.
549+
undefs : list[MacroUndef]
550+
A list of nodes which undefine the macros.
546551
"""
547-
tmpVar_x = Variable(element_type, 'x')
548-
tmpVar_y = Variable(element_type, 'y')
549-
if isinstance(element_type.primitive_type, PrimitiveComplexType):
550-
complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),()))
551-
self.add_import(complex_tool_import)
552-
imports_and_macros.append(complex_tool_import)
553-
compare_func = FunctionDef('complex_comparison',
554-
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
555-
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
556-
lt_def = compare_func(tmpVar_x, tmpVar_y)
552+
if isinstance(element_type, FixedSizeNumericType):
553+
tmpVar_x = Variable(element_type, 'x')
554+
tmpVar_y = Variable(element_type, 'y')
555+
if isinstance(element_type.primitive_type, PrimitiveComplexType):
556+
complex_tool_import = Import('pyc_tools_f90', Module('pyc_tools_f90',(),()))
557+
self.add_import(complex_tool_import)
558+
imports_and_macros.append(complex_tool_import)
559+
compare_func = FunctionDef('complex_comparison',
560+
[FunctionDefArgument(tmpVar_x), FunctionDefArgument(tmpVar_y)],
561+
[FunctionDefResult(Variable(PythonNativeBool(), 'c'))], [])
562+
lt_def = compare_func(tmpVar_x, tmpVar_y)
563+
else:
564+
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))
565+
566+
defs = [MacroDefinition(element_name, element_type.primitive_type),
567+
MacroDefinition(f'{element_name}_KINDLEN(context)', KindSpecification(element_type)),
568+
MacroDefinition(f'{element_name}_LT(x,y)', lt_def),
569+
MacroDefinition(f'{element_name}_EQ(x,y)', PyccelAssociativeParenthesis(PyccelEq(tmpVar_x, tmpVar_y)))]
570+
undefs = [MacroUndef(element_name),
571+
MacroUndef(f'{element_name}_KINDLEN'),
572+
MacroUndef(f'{element_name}_LT'),
573+
MacroUndef(f'{element_name}_EQ')]
557574
else:
558-
lt_def = PyccelAssociativeParenthesis(PyccelLt(tmpVar_x, tmpVar_y))
575+
defs = [MacroDefinition(element_name, element_type)]
576+
undefs = [MacroUndef(element_name)]
559577

560-
return lt_def
578+
if isinstance(element_type, (NumpyNDArrayType, HomogeneousTupleType)):
579+
defs.append(MacroDefinition(f'{element_name}_rank', element_type.rank))
580+
undefs.append(MacroUndef(f'{element_name}_rank'))
581+
elif not isinstance(element_type, FixedSizeNumericType):
582+
raise NotImplementedError("Support for containers of types defined in other modules is not yet implemented")
583+
return defs, undefs
561584

562585
def _build_gFTL_module(self, expr_type):
563586
"""
@@ -583,60 +606,40 @@ def _build_gFTL_module(self, expr_type):
583606
module = self._generated_gFTL_types[matching_expr_type]
584607
mod_name = module.name
585608
else:
586-
if isinstance(expr_type, HomogeneousListType):
609+
if isinstance(expr_type, (HomogeneousListType, HomogeneousSetType)):
587610
element_type = expr_type.element_type
588-
if isinstance(element_type, FixedSizeNumericType):
589-
imports_and_macros = [MacroDefinition('T', element_type.primitive_type),
590-
MacroDefinition('T_KINDLEN(context)', KindSpecification(element_type))]
611+
if isinstance(expr_type, HomogeneousSetType):
612+
type_name = 'Set'
613+
if not isinstance(element_type, FixedSizeNumericType):
614+
raise NotImplementedError("Support for sets of types which define their own < operator is not yet implemented")
591615
else:
592-
imports_and_macros = [MacroDefinition('T', element_type)]
593-
if isinstance(element_type, (NumpyNDArrayType, HomogeneousTupleType)):
594-
imports_and_macros.append(MacroDefinition('T_rank', element_type.rank))
595-
elif not isinstance(element_type, FixedSizeNumericType):
596-
raise NotImplementedError("Support for lists of types defined in other modules is not yet implemented")
597-
imports_and_macros.extend([MacroDefinition('Vector', expr_type),
598-
MacroDefinition('VectorIterator', IteratorType(expr_type)),
599-
Import(LiteralString('vector/template.inc'), Module('_', (), ())),
600-
MacroUndef('Vector'),
601-
MacroUndef('VectorIterator'),])
602-
elif isinstance(expr_type, HomogeneousSetType):
603-
element_type = expr_type.element_type
616+
type_name = 'Vector'
604617
imports_and_macros = []
605-
if isinstance(element_type, FixedSizeNumericType):
606-
lt_def = self._get_comparison_operator(element_type, imports_and_macros)
607-
imports_and_macros.extend([MacroDefinition('T', element_type.primitive_type),
608-
MacroDefinition('T_KINDLEN(context)', KindSpecification(element_type)),
609-
MacroDefinition('T_LT(x,y)', lt_def)])
610-
else:
611-
raise NotImplementedError("Support for sets of types which define their own < operator is not yet implemented")
612-
imports_and_macros.extend([MacroDefinition('Set', expr_type),
613-
MacroDefinition('SetIterator', IteratorType(expr_type)),
614-
Import(LiteralString('set/template.inc'), Module('_', (), ())),
615-
MacroUndef('Set'),
616-
MacroUndef('SetIterator')])
618+
defs, undefs = self._define_gFTL_element(element_type, imports_and_macros, 'T')
619+
imports_and_macros.extend([*defs,
620+
MacroDefinition(type_name, expr_type),
621+
MacroDefinition(f'{type_name}Iterator', IteratorType(expr_type)),
622+
Import(LiteralString(f'{type_name.lower()}/template.inc'), Module('_', (), ())),
623+
MacroUndef(type_name),
624+
MacroUndef(f'{type_name}Iterator'),
625+
*undefs])
617626
elif isinstance(expr_type, DictType):
618627
key_type = expr_type.key_type
619628
value_type = expr_type.value_type
620629
imports_and_macros = []
621-
if isinstance(key_type, FixedSizeNumericType):
622-
lt_def = self._get_comparison_operator(key_type, imports_and_macros)
623-
imports_and_macros.extend([MacroDefinition('Key', key_type.primitive_type),
624-
MacroDefinition('Key_KINDLEN(context)', KindSpecification(key_type)),
625-
MacroDefinition('Key_LT(x,y)', lt_def)])
626-
else:
630+
if not isinstance(key_type, FixedSizeNumericType):
627631
raise NotImplementedError("Support for dicts whose keys define their own < operator is not yet implemented")
628-
if isinstance(value_type, FixedSizeNumericType):
629-
imports_and_macros.extend([MacroDefinition('T', value_type.primitive_type),
630-
MacroDefinition('T_KINDLEN(context)', KindSpecification(value_type))])
631-
else:
632-
raise NotImplementedError(f"Support for dictionary values of type {value_type} not yet implemented")
633-
imports_and_macros.extend([MacroDefinition('Pair', PairType(key_type, value_type)),
632+
key_defs, key_undefs = self._define_gFTL_element(key_type, imports_and_macros, 'Key')
633+
val_defs, val_undefs = self._define_gFTL_element(value_type, imports_and_macros, 'T')
634+
imports_and_macros.extend([*key_defs, *val_defs,
635+
MacroDefinition('Pair', PairType(key_type, value_type)),
634636
MacroDefinition('Map', expr_type),
635637
MacroDefinition('MapIterator', IteratorType(expr_type)),
636638
Import(LiteralString('map/template.inc'), Module('_', (), ())),
637639
MacroUndef('Pair'),
638640
MacroUndef('Map'),
639-
MacroUndef('MapIterator')])
641+
MacroUndef('MapIterator'),
642+
*key_undefs, *val_undefs])
640643
else:
641644
raise NotImplementedError(f"Unkown gFTL import for type {expr_type}")
642645

@@ -3077,6 +3080,19 @@ def _print_PyccelNot(self, expr):
30773080
return '{} == 0'.format(a)
30783081
return '.not. {}'.format(a)
30793082

3083+
def _print_PyccelIn(self, expr):
3084+
container_type = expr.container.class_type
3085+
element = self._print(expr.element)
3086+
container = self._print(expr.container)
3087+
if isinstance(container_type, (HomogeneousSetType, DictType)):
3088+
return f'{container} % count({element}) /= 0'
3089+
elif isinstance(container_type, HomogeneousListType):
3090+
return f'{container} % get_index({element}) /= 0'
3091+
else:
3092+
raise errors.report(PYCCEL_RESTRICTION_TODO,
3093+
symbol = expr,
3094+
severity='fatal')
3095+
30803096
def _print_Header(self, expr):
30813097
return ''
30823098

pyccel/codegen/printing/pycode.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,11 @@ def _print_Duplicate(self, expr):
10911091
def _print_Concatenate(self, expr):
10921092
return ' + '.join([self._print(a) for a in expr.args])
10931093

1094+
def _print_PyccelIn(self, expr):
1095+
element = self._print(expr.element)
1096+
container = self._print(expr.container)
1097+
return f'{element} in {container}'
1098+
10941099
def _print_PyccelSymbol(self, expr):
10951100
return expr
10961101

pyccel/parser/semantic.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898

9999
from pyccel.ast.operators import PyccelArithmeticOperator, PyccelIs, PyccelIsNot, IfTernaryOperator, PyccelUnarySub
100100
from pyccel.ast.operators import PyccelNot, PyccelAdd, PyccelMinus, PyccelMul, PyccelPow
101-
from pyccel.ast.operators import PyccelAssociativeParenthesis, PyccelDiv
101+
from pyccel.ast.operators import PyccelAssociativeParenthesis, PyccelDiv, PyccelIn
102102

103103
from pyccel.ast.sympy_helper import sympy_to_pyccel, pyccel_to_sympy
104104

@@ -2926,6 +2926,25 @@ def _visit_PyccelPow(self, expr):
29262926
else:
29272927
return PyccelPow(base, exponent)
29282928

2929+
def _visit_PyccelIn(self, expr):
2930+
element = self._visit(expr.element)
2931+
container = self._visit(expr.container)
2932+
container_type = container.class_type
2933+
if isinstance(container_type, (DictType, HomogeneousSetType, HomogeneousListType)):
2934+
element_type = container_type.key_type if isinstance(container_type, DictType) else container_type.element_type
2935+
if element.class_type == element_type:
2936+
return PyccelIn(element, container)
2937+
else:
2938+
return LiteralFalse()
2939+
2940+
container_base = self.scope.find(str(container_type), 'classes') or get_cls_base(container_type)
2941+
contains_method = container_base.get_method('__contains__', raise_error = isinstance(container_type, CustomDataType))
2942+
if contains_method:
2943+
return contains_method(container, element)
2944+
else:
2945+
raise errors.report(f"In operator is not yet implemented for type {container_type}",
2946+
severity='fatal', symbol=expr)
2947+
29292948
def _visit_Lambda(self, expr):
29302949
errors.report("Lambda functions are not currently supported",
29312950
symbol=expr, severity='fatal')

pyccel/parser/syntactic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from pyccel.ast.operators import PyccelEq, PyccelNe, PyccelLt, PyccelLe, PyccelGt, PyccelGe
5050
from pyccel.ast.operators import PyccelAnd, PyccelOr, PyccelNot, PyccelMinus
5151
from pyccel.ast.operators import PyccelUnary, PyccelUnarySub
52-
from pyccel.ast.operators import PyccelIs, PyccelIsNot
52+
from pyccel.ast.operators import PyccelIs, PyccelIsNot, PyccelIn
5353
from pyccel.ast.operators import IfTernaryOperator
5454
from pyccel.ast.numpyext import NumpyMatmul
5555

@@ -715,6 +715,8 @@ def _visit_Compare(self, stmt):
715715
return PyccelIs(first, second)
716716
if isinstance(op, ast.IsNot):
717717
return PyccelIsNot(first, second)
718+
if isinstance(op, ast.In):
719+
return PyccelIn(first, second)
718720

719721
return errors.report(PYCCEL_RESTRICTION_UNSUPPORTED_SYNTAX,
720722
symbol = stmt,

0 commit comments

Comments
 (0)