Skip to content

Commit a0ee291

Browse files
Add support for set method intersection pyccel#1744 (pyccel#2072)
Add support for set method intersection. Fixes pyccel#1744 **Commit Summary** - Add a class to represent `SetIntersection - Add intersection method to class description - add tests for set intersection method - update changelog and docs --------- Co-authored-by: Emily Bourne <[email protected]>
1 parent 0341065 commit a0ee291

File tree

6 files changed

+141
-2
lines changed

6 files changed

+141
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ All notable changes to this project will be documented in this file.
4444
- #1918 : Add support for set method `clear()`.
4545
- #1918 : Add support for set method `copy()`.
4646
- #1753 : Add support for set method `union()`.
47+
- #1744 : Add Python support for set method `intersection()`.
4748
- #1884 : Add support for dict method `items()`.
4849
- #1936 : Add missing C output for inline decorator example in documentation
4950
- #1937 : Optimise `pyccel.ast.basic.PyccelAstNode.substitute` method.

docs/builtin-functions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ Python contains a limited number of builtin functions defined [here](https://doc
102102
| `difference` | No |
103103
| `difference_update` | No |
104104
| `discard` | Python-only |
105-
| `intersection` | No |
105+
| `intersection` | Python-only |
106106
| `intersection_update` | No |
107107
| `isdisjoint` | No |
108108
| `issubset` | No |

pyccel/ast/builtin_methods/set_methods.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
'SetClear',
1919
'SetCopy',
2020
'SetDiscard',
21+
'SetIntersection',
2122
'SetMethod',
2223
'SetPop',
2324
'SetUnion',
@@ -232,3 +233,32 @@ def __init__(self, set_obj, *others):
232233
raise TypeError(f"Argument of type {o.class_type} cannot be used to build set of type {self._class_type}")
233234
self._shape = (None,)*self._class_type.rank
234235
super().__init__(set_obj, *others)
236+
237+
#==============================================================================
238+
239+
class SetIntersection(SetMethod):
240+
"""
241+
Represents a call to the set method .intersection.
242+
243+
Represents a call to the set method .intersection. This method builds a new set
244+
by including all elements which appear in "both" of the iterables
245+
(the set object and the arguments).
246+
247+
Parameters
248+
----------
249+
set_obj : TypedAstNode
250+
The set object which the method is called from.
251+
*others : TypedAstNode
252+
The iterables which will be combined (common elements) with this set.
253+
"""
254+
__slots__ = ('_other','_class_type', '_shape')
255+
name = 'intersection'
256+
257+
def __init__(self, set_obj, *others):
258+
self._class_type = set_obj.class_type
259+
element_type = self._class_type.element_type
260+
for o in others:
261+
if element_type != o.class_type.element_type:
262+
raise TypeError(f"Argument fo type {o.type_class} cannot be used to build set of type {self._class_type}")
263+
self._shape = (None,)*self.rank
264+
super().__init__(set_obj, *others)

pyccel/ast/class_defs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
from pyccel.ast.builtin_methods.set_methods import (SetAdd, SetClear, SetCopy, SetPop,
10-
SetDiscard, SetUpdate, SetUnion)
10+
SetDiscard, SetUpdate, SetUnion, SetIntersection)
1111
from pyccel.ast.builtin_methods.list_methods import (ListAppend, ListInsert, ListPop,
1212
ListClear, ListExtend, ListRemove,
1313
ListCopy, ListSort)
@@ -166,8 +166,10 @@
166166
PyccelFunctionDef('pop', func_class = SetPop),
167167
PyccelFunctionDef('remove', func_class = SetDiscard),
168168
PyccelFunctionDef('union', func_class = SetUnion),
169+
PyccelFunctionDef('intersection', func_class = SetIntersection),
169170
PyccelFunctionDef('update', func_class = SetUpdate),
170171
PyccelFunctionDef('__or__', func_class = SetUnion),
172+
PyccelFunctionDef('__and__', func_class = SetIntersection),
171173
PyccelFunctionDef('__ior__', func_class = SetUpdate),
172174
])
173175

pyccel/parser/syntactic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ def _visit_AugAssign(self, stmt):
442442
return AugAssign(lhs, '%', rhs)
443443
elif isinstance(stmt.op, ast.BitOr):
444444
return AugAssign(lhs, '|', rhs)
445+
elif isinstance(stmt.op, ast.BitAnd):
446+
return AugAssign(lhs, '&', rhs)
445447
else:
446448
return errors.report(PYCCEL_RESTRICTION_TODO, symbol = stmt,
447449
severity='error')

tests/epyccel/test_epyccel_sets.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,46 @@ def union_int():
425425
assert python_result[0] == pyccel_result[0]
426426
assert set(python_result[1:]) == set(pyccel_result[1:])
427427

428+
def test_set_intersection_int(python_only_language):
429+
def intersection_int():
430+
a = {1,2,3}
431+
b = {2,3,4}
432+
c = a.intersection(b)
433+
return len(c), c.pop(), c.pop()
434+
435+
epyccel_func = epyccel(intersection_int, language = python_only_language)
436+
pyccel_result = epyccel_func()
437+
python_result = intersection_int()
438+
assert python_result[0] == pyccel_result[0]
439+
assert set(python_result[1:]) == set(pyccel_result[1:])
440+
441+
def test_set_intersection_no_args(python_only_language):
442+
def intersection_int():
443+
a = {1,2,3,4}
444+
c = a.intersection()
445+
a.add(5)
446+
return len(c), c.pop(), c.pop(), c.pop(), c.pop()
447+
448+
epyccel_func = epyccel(intersection_int, language = python_only_language)
449+
pyccel_result = epyccel_func()
450+
python_result = intersection_int()
451+
assert python_result[0] == pyccel_result[0]
452+
assert set(python_result[1:]) == set(pyccel_result[1:])
453+
454+
def test_set_intersection_2_args(python_only_language):
455+
def intersection_int():
456+
a = {1,2,3,4}
457+
b = {5,6,7,2,1,3}
458+
c = {7,6,10,4,2,3,1}
459+
d = a.intersection(b, c)
460+
return len(d), d.pop(), d.pop(), d.pop()
461+
462+
epyccel_func = epyccel(intersection_int, language = python_only_language)
463+
pyccel_result = epyccel_func()
464+
python_result = intersection_int()
465+
assert python_result[0] == pyccel_result[0]
466+
assert set(python_result[1:]) == set(pyccel_result[1:])
467+
428468
@pytest.mark.parametrize( 'language', (
429469
pytest.param("fortran", marks = pytest.mark.fortran),
430470
pytest.param("c", marks = [
@@ -485,6 +525,57 @@ def union_int():
485525
assert python_result[0] == pyccel_result[0]
486526
assert set(python_result[1:]) == set(pyccel_result[1:])
487527

528+
def test_temporary_set_intersection(python_only_language):
529+
def intersection_int():
530+
a = {1,2}
531+
b = {2}
532+
d = a.intersection(b).pop()
533+
return d
534+
535+
epyccel_func = epyccel(intersection_int, language = python_only_language)
536+
pyccel_result = epyccel_func()
537+
python_result = intersection_int()
538+
assert python_result == pyccel_result
539+
540+
def test_set_intersection_list(python_only_language):
541+
def intersection_list():
542+
a = {1.2, 2.3, 5.0}
543+
b = [1.2, 5.0, 4.0]
544+
d = a.intersection(b)
545+
return len(d), d.pop(), d.pop()
546+
547+
epyccel_func = epyccel(intersection_list, language = python_only_language)
548+
pyccel_result = epyccel_func()
549+
python_result = intersection_list()
550+
assert python_result[0] == pyccel_result[0]
551+
assert set(python_result[1:]) == set(pyccel_result[1:])
552+
553+
def test_set_intersection_tuple(python_only_language):
554+
def intersection_tuple():
555+
a = {True}
556+
b = (False, True)
557+
d = a.intersection(b)
558+
return len(d), d.pop()
559+
560+
epyccel_func = epyccel(intersection_tuple, language = python_only_language)
561+
pyccel_result = epyccel_func()
562+
python_result = intersection_tuple()
563+
assert python_result[0] == pyccel_result[0]
564+
assert set(python_result[1:]) == set(pyccel_result[1:])
565+
566+
def test_set_intersection_operator(python_only_language):
567+
def intersection_int():
568+
a = {1,2,3,4,8}
569+
b = {5,2,3,7,8}
570+
c = a & b
571+
return len(c), c.pop(), c.pop(), c.pop()
572+
573+
epyccel_func = epyccel(intersection_int, language = python_only_language)
574+
pyccel_result = epyccel_func()
575+
python_result = intersection_int()
576+
assert python_result[0] == pyccel_result[0]
577+
assert set(python_result[1:]) == set(pyccel_result[1:])
578+
488579
@pytest.mark.parametrize( 'language', (
489580
pytest.param("fortran", marks = [
490581
pytest.mark.xfail(reason="Update not fully implemented yet. See #2022"),
@@ -521,6 +612,19 @@ def union_int():
521612
python_result = union_int()
522613
assert python_result == pyccel_result
523614

615+
def test_set_intersection_augoperator(python_only_language):
616+
def intersection_int():
617+
a = {1,2,3,4}
618+
b = {2,3,4}
619+
a &= b
620+
return len(a), a.pop(), a.pop(), a.pop()
621+
622+
epyccel_func = epyccel(intersection_int, language = python_only_language)
623+
pyccel_result = epyccel_func()
624+
python_result = intersection_int()
625+
assert python_result[0] == pyccel_result[0]
626+
assert set(python_result[1:]) == set(pyccel_result[1:])
627+
524628
def test_set_ptr(language):
525629
def set_ptr():
526630
a = {1,2,3,4,5,6,7,8}

0 commit comments

Comments
 (0)