Skip to content

Commit 86f761d

Browse files
authored
Allow iterating over a set (pyccel#2030)
Add support for iterating over a set. Fixes pyccel#2023 **Commit Summary** - Add class type information for `zip` - Complete definition of `IteratorType` following documentation - Add printing support for iterating over a set - Add iterator tests for sets
1 parent c7d846c commit 86f761d

File tree

6 files changed

+145
-65
lines changed

6 files changed

+145
-65
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ All notable changes to this project will be documented in this file.
3535
- #1689 : Add C and Fortran support for list method `append()`.
3636
- #1876 : Add C support for indexing lists.
3737
- #1690 : Add C support for list method `pop()`.
38+
- #2023 : Add support for iterating over a `set`.
3839
- #1877 : Add C and Fortran Support for set method `pop()`.
3940
- #1917 : Add C and Fortran support for set method `add()`.
4041
- #1918 : Add support for set method `clear()`.

pyccel/ast/builtins.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1150,7 +1150,7 @@ class PythonZip(PyccelFunction):
11501150
*args : tuple of TypedAstNode
11511151
The arguments passed to the function.
11521152
"""
1153-
__slots__ = ('_length',)
1153+
__slots__ = ('_length', '_class_type')
11541154
name = 'zip'
11551155

11561156
def __init__(self, *args):
@@ -1168,6 +1168,7 @@ def __init__(self, *args):
11681168
self._length = min(lengths)
11691169
else:
11701170
self._length = self.args[0].shape[0]
1171+
self._class_type = InhomogeneousTupleType(*[a.class_type for a in args])
11711172

11721173
@property
11731174
def length(self):

pyccel/ast/low_level_tools.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,39 @@ def iterable_type(self):
4242
"""
4343
return self._iterable_type
4444

45+
def __str__(self):
46+
return f'Iter[{self._iterable_type}]'
47+
48+
@property
49+
def datatype(self):
50+
"""
51+
The datatype of the object.
52+
53+
The datatype of the object.
54+
"""
55+
return self
56+
57+
@property
58+
def rank(self):
59+
"""
60+
Number of dimensions of the object.
61+
62+
Number of dimensions of the object. If the object is a scalar then
63+
this is equal to 0.
64+
"""
65+
return 0
66+
67+
@property
68+
def order(self):
69+
"""
70+
The data layout ordering in memory.
71+
72+
Indicates whether the data is stored in row-major ('C') or column-major
73+
('F') format. This is only relevant if rank > 1. When it is not relevant
74+
this function returns None.
75+
"""
76+
return None
77+
4578
#------------------------------------------------------------------------------
4679
class PairType(PyccelType, metaclass=ArgumentSingleton):
4780
"""

pyccel/codegen/printing/ccode.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
from pyccel.ast.literals import LiteralString, LiteralInteger, Literal
4343
from pyccel.ast.literals import Nil
4444

45+
from pyccel.ast.low_level_tools import IteratorType
46+
4547
from pyccel.ast.mathext import math_constants
4648

4749
from pyccel.ast.numpyext import NumpyFull, NumpyArray
@@ -2326,45 +2328,44 @@ def _print_AliasAssign(self, expr):
23262328
def _print_For(self, expr):
23272329
self.set_scope(expr.scope)
23282330

2329-
indices = expr.iterable.loop_counters
2330-
index = indices[0] if indices else expr.target
2331-
if expr.iterable.num_loop_counters_required:
2332-
self.scope.insert_variable(index)
2333-
2334-
target = index
2335-
iterable = expr.iterable.get_range()
2336-
2337-
if not isinstance(iterable, PythonRange):
2338-
# Only iterable currently supported is PythonRange
2339-
errors.report(PYCCEL_RESTRICTION_TODO, symbol=expr,
2340-
severity='fatal')
2341-
2342-
counter = self._print(target)
2343-
body = self._print(expr.body)
2344-
2345-
additional_assign = CodeBlock(expr.iterable.get_assigns(expr.target))
2346-
body = self._print(additional_assign) + body
2347-
2348-
start = self._print(iterable.start)
2349-
stop = self._print(iterable.stop )
2350-
step = self._print(iterable.step )
2331+
iterable = expr.iterable
2332+
iterable_type = iterable.iterable.class_type
2333+
indices = iterable.loop_counters
2334+
2335+
if isinstance(iterable_type, (DictType, HomogeneousSetType, HomogeneousListType)):
2336+
counter = Variable(IteratorType(iterable_type), indices[0].name)
2337+
c_type = self.get_c_type(iterable_type)
2338+
iterable_code = self._print(iterable.iterable)
2339+
for_code = f'c_foreach ({self._print(counter)}, {c_type}, {iterable_code})'
2340+
additional_assign = CodeBlock([Assign(expr.target, DottedVariable(VoidType(), 'ref',
2341+
memory_handling='alias', lhs = counter))])
2342+
else:
2343+
index = indices[0] if indices else expr.target
2344+
if iterable.num_loop_counters_required:
2345+
self.scope.insert_variable(index)
2346+
2347+
target = index
2348+
counter = self._print(target)
2349+
iterable = expr.iterable.get_range()
2350+
additional_assign = CodeBlock(expr.iterable.get_assigns(expr.target))
2351+
2352+
step = iterable.step
2353+
start_code = self._print(iterable.start)
2354+
stop_code = self._print(iterable.stop )
2355+
step_code = self._print(iterable.step )
2356+
2357+
# testing if the step is a value or an expression
2358+
if is_literal_integer(step):
2359+
op = '>' if int(step) < 0 else '<'
2360+
stop_condition = f'{counter} {op} {stop_code}'
2361+
else:
2362+
stop_condition = f'({step_code} > 0) ? ({counter} < {stop_code}) : ({counter} > {stop_code})'
2363+
for_code = f'for ({counter} = {start_code}; {stop_condition}; {counter} += {step_code})\n'
23512364

2352-
test_step = iterable.step
2353-
if isinstance(test_step, PyccelUnarySub):
2354-
test_step = iterable.step.args[0]
2365+
body = self._print(additional_assign) + self._print(expr.body)
23552366

23562367
self.exit_scope()
2357-
# testing if the step is a value or an expression
2358-
if isinstance(test_step, Literal):
2359-
op = '>' if isinstance(iterable.step, PyccelUnarySub) else '<'
2360-
return ('for ({counter} = {start}; {counter} {op} {stop}; {counter} += '
2361-
'{step})\n{{\n{body}}}\n').format(counter=counter, start=start, op=op,
2362-
stop=stop, step=step, body=body)
2363-
else:
2364-
return (
2365-
'for ({counter} = {start}; ({step} > 0) ? ({counter} < {stop}) : ({counter} > {stop}); {counter} += '
2366-
'{step})\n{{\n{body}}}\n').format(counter=counter, start=start,
2367-
stop=stop, step=step, body=body)
2368+
return for_code + '{\n' + body + '}\n'
23682369

23692370
def _print_FunctionalFor(self, expr):
23702371
loops = ''.join(self._print(i) for i in expr.loops)

pyccel/codegen/printing/fcode.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,8 +1865,10 @@ def _print_Declare(self, expr):
18651865
rankstr = ''
18661866

18671867
# ... print datatype
1868-
if isinstance(expr_type, CustomDataType):
1869-
name = expr_type.name
1868+
if isinstance(expr_type, (CustomDataType, IteratorType, HomogeneousListType, HomogeneousSetType, DictType)):
1869+
name = self._print(expr_type)
1870+
if isinstance(expr_type, (HomogeneousContainerType, DictType)):
1871+
self.add_import(self._build_gFTL_module(expr_type))
18701872

18711873
if var.is_argument:
18721874
sig = 'class'
@@ -1894,10 +1896,6 @@ def _print_Declare(self, expr):
18941896
raise NotImplementedError("Fortran rank string undetermined")
18951897
rankstr = f'({rankstr})'
18961898

1897-
elif isinstance(expr_type, (HomogeneousListType, HomogeneousSetType, DictType)):
1898-
self.add_import(self._build_gFTL_module(expr_type))
1899-
typename = self._print(expr_type)
1900-
dtype_str = f'type({typename})'
19011899
elif isinstance(dtype, StringType):
19021900
dtype_str = self._print(dtype)
19031901

@@ -2231,6 +2229,9 @@ def _print_IteratorType(self, expr):
22312229
iterable_type = self._print(expr.iterable_type)
22322230
return f"{iterable_type}_Iterator"
22332231

2232+
def _print_CustomDataType(self, expr):
2233+
return expr.name
2234+
22342235
def _print_DataType(self, expr):
22352236
return self._print(expr.name)
22362237

@@ -2568,27 +2569,42 @@ def _print_FunctionalFor(self, expr):
25682569
def _print_For(self, expr):
25692570
self.set_scope(expr.scope)
25702571

2572+
iterable = expr.iterable
2573+
iterable_type = iterable.iterable.class_type
25712574
indices = expr.iterable.loop_counters
2572-
index = indices[0] if indices else expr.target
2573-
if expr.iterable.num_loop_counters_required:
2574-
self.scope.insert_variable(index)
25752575

2576-
target = index
2577-
my_range = expr.iterable.get_range()
2578-
2579-
if not isinstance(my_range, PythonRange):
2580-
# Only iterable currently supported is PythonRange
2581-
errors.report(PYCCEL_RESTRICTION_TODO, symbol=expr,
2582-
severity='fatal')
2583-
2584-
tar = self._print(target)
2585-
range_code = self._print(my_range)
2586-
2587-
prolog = 'do {0} = {1}\n'.format(tar, range_code)
2588-
epilog = 'end do\n'
2589-
2590-
additional_assign = CodeBlock(expr.iterable.get_assigns(expr.target))
2591-
prolog += self._print(additional_assign)
2576+
if isinstance(iterable_type, (DictType, HomogeneousSetType)):
2577+
if isinstance(iterable.iterable, Variable):
2578+
suggested_name = iterable.iterable.name + '_'
2579+
else:
2580+
suggested_name = ''
2581+
errors.report("Iterating over a temporary object. This may cause compilation issues or cause calculations to be carried out twice",
2582+
severity='warning', symbol=expr)
2583+
iterable = self._print(iterable.iterable)
2584+
iterator = self.scope.get_temporary_variable(IteratorType(iterable_type),
2585+
name = suggested_name + 'iter')
2586+
last = self.scope.get_temporary_variable(IteratorType(iterable_type),
2587+
name = suggested_name + 'last')
2588+
target = self._print(expr.target)
2589+
prolog = (f'{iterator} = {iterable} % begin()\n'
2590+
f'{last} = {iterable} % end()\n'
2591+
f'do while ({iterator} /= {last})\n'
2592+
f'{target} = {iterator} % of()\n')
2593+
epilog = f'call {iterator} % next()\nend do\n'
2594+
else:
2595+
index = indices[0] if indices else expr.target
2596+
if iterable.num_loop_counters_required:
2597+
self.scope.insert_variable(index)
2598+
2599+
my_range = iterable.get_range()
2600+
2601+
target = self._print(index)
2602+
range_code = self._print(my_range)
2603+
2604+
prolog = f'do {target} = {range_code}\n'
2605+
epilog = 'end do\n'
2606+
2607+
prolog += self._print(CodeBlock(iterable.get_assigns(expr.target)))
25922608

25932609
body = self._print(expr.body)
25942610

@@ -2598,9 +2614,7 @@ def _print_For(self, expr):
25982614

25992615
self.exit_scope()
26002616

2601-
return ('{prolog}'
2602-
'{body}'
2603-
'{epilog}').format(prolog=prolog, body=body, epilog=epilog)
2617+
return prolog + body + epilog
26042618

26052619
# .....................................................
26062620
# Print OpenMP AnnotatedComment

tests/epyccel/test_epyccel_sets.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,3 +532,33 @@ def set_ptr():
532532
pyccel_result = epyccel_func()
533533
python_result = set_ptr()
534534
assert python_result == pyccel_result
535+
536+
def test_set_iter(language):
537+
def set_sum_int():
538+
a = {1,2,3,4,5,6,7,8,9,12}
539+
sum_a = 0
540+
for ai in a:
541+
sum_a += ai
542+
return sum_a
543+
544+
epyccel_func = epyccel(set_sum_int, language = language)
545+
pyccel_result = epyccel_func()
546+
python_result = set_sum_int()
547+
assert python_result == pyccel_result
548+
assert isinstance(python_result, type(pyccel_result))
549+
550+
def test_set_iter_prod(language):
551+
def set_iter_prod():
552+
from itertools import product
553+
a = {1,2,3,4,5,6,7,8,9,12}
554+
b = {2.0, 4.0, 9.0, 2.5, 8.3}
555+
assemble = 0.0
556+
for ai, bi in product(a,b):
557+
assemble += ai*bi
558+
return assemble
559+
560+
epyccel_func = epyccel(set_iter_prod, language = language)
561+
pyccel_result = epyccel_func()
562+
python_result = set_iter_prod()
563+
assert python_result == pyccel_result
564+
assert isinstance(python_result, type(pyccel_result))

0 commit comments

Comments
 (0)