Skip to content

Commit b5abf6f

Browse files
authored
Add C and Fortran support for set.intersection and set.intersection_update (pyccel#2117)
Add C and Fortran support for set methods `intersection` and `intersection_update`. Fixes pyccel#1745 Map `set.intersection` to `set.copy` and `set.intersection_update`. Activate existing tests and add tests for `intersection_update`. Remove tests for `intersection` taking lists or tuples as arguments. This can be added back later if it is requested
1 parent eb9b7f6 commit b5abf6f

File tree

10 files changed

+164
-40
lines changed

10 files changed

+164
-40
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ All notable changes to this project will be documented in this file.
4646
- #1918 : Add support for set method `copy()`.
4747
- #1753 : Add support for set method `union()`.
4848
- #1754 : Add support for set method `update()`.
49-
- #1744 : Add Python support for set method `intersection()`.
49+
- #1744 : Add support for set method `intersection()`.
50+
- #1745 : Add support for set method `intersection_update()`.
5051
- #1884 : Add support for dict method `items()`.
5152
- #1936 : Add missing C output for inline decorator example in documentation
5253
- #1937 : Optimise `pyccel.ast.basic.PyccelAstNode.substitute` method.

docs/builtin-functions.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ Python contains a limited number of builtin functions defined [here](https://doc
102102
| `difference` | No |
103103
| `difference_update` | No |
104104
| `discard` | Python-only |
105-
| `intersection` | Python-only |
106-
| `intersection_update` | No |
105+
| `intersection` | **Yes** |
106+
| `intersection_update` | **Yes** |
107107
| `isdisjoint` | No |
108108
| `issubset` | No |
109109
| `issuperset` | No |

pyccel/ast/builtin_methods/set_methods.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'SetCopy',
2020
'SetDiscard',
2121
'SetIntersection',
22+
'SetIntersectionUpdate',
2223
'SetMethod',
2324
'SetPop',
2425
'SetUnion',
@@ -254,11 +255,30 @@ class SetIntersection(SetMethod):
254255
__slots__ = ('_other','_class_type', '_shape')
255256
name = 'intersection'
256257

258+
#==============================================================================
259+
260+
class SetIntersectionUpdate(SetMethod):
261+
"""
262+
Represents a call to the .intersection_update() method.
263+
264+
Represents a call to the set method .intersection_update(). This method combines
265+
two sets by including all elements which appear in all of the sets.
266+
267+
Parameters
268+
----------
269+
set_obj : TypedAstNode
270+
The set object which the method is called from.
271+
*others : TypedAstNode
272+
The sets which will be combined with this set.
273+
"""
274+
__slots__ = ()
275+
name = 'intersection_update'
276+
_class_type = VoidType()
277+
_shape = None
278+
257279
def __init__(self, set_obj, *others):
258-
self._class_type = set_obj.class_type
259-
element_type = self._class_type.element_type
280+
class_type = set_obj.class_type
260281
for o in others:
261-
if element_type != o.class_type.element_type:
262-
raise TypeError(f"Argument fo type {o.type_class} cannot be used to build set of type {self._class_type}")
263-
self._shape = (None,)*self.rank
282+
if class_type != o.class_type:
283+
raise TypeError(f"Only arguments of type {class_type} are supported for the functions intersection and .intersection_update")
264284
super().__init__(set_obj, *others)

pyccel/ast/class_defs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
"""
88

99
from pyccel.ast.builtin_methods.set_methods import (SetAdd, SetClear, SetCopy, SetPop,
10-
SetDiscard, SetUpdate, SetUnion, SetIntersection)
10+
SetDiscard, SetUpdate, SetUnion,
11+
SetIntersection, SetIntersectionUpdate)
1112
from pyccel.ast.builtin_methods.list_methods import (ListAppend, ListInsert, ListPop,
1213
ListClear, ListExtend, ListRemove,
1314
ListCopy, ListSort)
@@ -167,9 +168,11 @@
167168
PyccelFunctionDef('remove', func_class = SetDiscard),
168169
PyccelFunctionDef('union', func_class = SetUnion),
169170
PyccelFunctionDef('intersection', func_class = SetIntersection),
171+
PyccelFunctionDef('intersection_update', func_class = SetIntersectionUpdate),
170172
PyccelFunctionDef('update', func_class = SetUpdate),
171173
PyccelFunctionDef('__or__', func_class = SetUnion),
172174
PyccelFunctionDef('__and__', func_class = SetIntersection),
175+
PyccelFunctionDef('__iand__', func_class = SetIntersectionUpdate),
173176
PyccelFunctionDef('__ior__', func_class = SetUpdate),
174177
])
175178

pyccel/codegen/printing/ccode.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2796,6 +2796,14 @@ def _print_SetUnion(self, expr):
27962796
args = ', '.join([str(len(expr.args)), *(self._print(ObjectAddress(a)) for a in expr.args)])
27972797
return f'{var_type}_union({set_var}, {args})'
27982798

2799+
def _print_SetIntersectionUpdate(self, expr):
2800+
class_type = expr.set_variable.class_type
2801+
var_type = self.get_c_type(class_type)
2802+
self.add_import(Import('Set_extensions', AsName(VariableTypeAnnotation(class_type), var_type)))
2803+
set_var = self._print(ObjectAddress(expr.set_variable))
2804+
return ''.join(f'{var_type}_intersection_update({set_var}, {self._print(ObjectAddress(a))});\n' \
2805+
for a in expr.args)
2806+
27992807
#=================== MACROS ==================
28002808

28012809
def _print_MacroShape(self, expr):

pyccel/codegen/printing/fcode.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,15 @@ def _print_SetUnion(self, expr):
14171417
self._additional_code += code
14181418
return result
14191419

1420+
def _print_SetIntersectionUpdate(self, expr):
1421+
var = expr.set_variable
1422+
expr_type = var.class_type
1423+
var_code = self._print(expr.set_variable)
1424+
type_name = self._print(expr_type)
1425+
self.add_import(self._build_gFTL_extension_module(expr_type))
1426+
return ''.join(f'call {type_name}_intersection_update({var_code}, {self._print(arg)})\n' \
1427+
for arg in expr.args)
1428+
14201429
#========================== Numpy Elements ===============================#
14211430

14221431
def _print_NumpySum(self, expr):

pyccel/parser/semantic.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
PythonTuple, Lambda, PythonMap)
3535

3636
from pyccel.ast.builtin_methods.list_methods import ListMethod, ListAppend
37-
from pyccel.ast.builtin_methods.set_methods import SetAdd, SetUnion
37+
from pyccel.ast.builtin_methods.set_methods import SetAdd, SetUnion, SetCopy, SetIntersectionUpdate
3838

3939
from pyccel.ast.core import Comment, CommentBlock, Pass
4040
from pyccel.ast.core import If, IfSection
@@ -5463,3 +5463,48 @@ def _build_SetUnion(self, expr, function_call_args):
54635463
for c in update_calls]
54645464
pyccel_stage.set_stage('semantic')
54655465
return CodeBlock([self._visit(b) for b in body])
5466+
5467+
def _build_SetIntersection(self, expr, function_call_args):
5468+
"""
5469+
Method to visit a SetIntersection node.
5470+
5471+
The purpose of this `_build` method is to construct multiple nodes to represent
5472+
the single DottedName node representing the call to SetIntersection. It
5473+
replaces the call with a call to copy followed by multiple calls to
5474+
SetIntersectionUpdate.
5475+
5476+
Parameters
5477+
----------
5478+
expr : DottedName
5479+
The syntactic DottedName node that represent the call to `.intersection()`.
5480+
5481+
function_call_args : iterable[FunctionCallArgument]
5482+
The semantic arguments passed to the function.
5483+
5484+
Returns
5485+
-------
5486+
CodeBlock
5487+
CodeBlock containing SetCopy and SetIntersectionUpdate objects.
5488+
"""
5489+
start_set = function_call_args[0].value
5490+
set_args = [self._visit(a.value) for a in function_call_args[1:]]
5491+
assign = expr.get_direct_user_nodes(lambda a: isinstance(a, Assign))
5492+
if assign:
5493+
syntactic_lhs = assign[-1].lhs
5494+
else:
5495+
syntactic_lhs = self.scope.get_new_name()
5496+
d_var = self._infer_type(start_set)
5497+
rhs = SetCopy(start_set)
5498+
body = []
5499+
lhs = self._assign_lhs_variable(syntactic_lhs, d_var, rhs, body)
5500+
body.append(Assign(lhs, rhs, python_ast = expr.python_ast))
5501+
try:
5502+
body += [SetIntersectionUpdate(lhs, s) for s in set_args]
5503+
except TypeError as e:
5504+
errors.report(e, symbol=expr, severity='error')
5505+
if assign:
5506+
return CodeBlock(body)
5507+
else:
5508+
self._additional_exprs[-1].extend(body)
5509+
return lhs
5510+

pyccel/stdlib/STC_Extensions/Set_extensions.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,24 @@ static inline i_type _c_MEMB(_union)(i_type* self, int n, ...) {
3838
return union_result;
3939
}
4040

41+
/**
42+
* This function represents a call to the .intersection_update() method.
43+
* @param self : The set instance to modify.
44+
* @param other : The other set in which elements must be found.
45+
*/
46+
static inline void _c_MEMB(_intersection_update)(i_type* self, i_type* other) {
47+
_c_MEMB(_iter) itr = _c_MEMB(_begin)(self);
48+
while (itr.ref)
49+
{
50+
i_key val = (*itr.ref);
51+
if (_c_MEMB(_contains)(other, val)) {
52+
_c_MEMB(_next)(&itr);
53+
} else {
54+
itr = _c_MEMB(_erase_at)(self, itr);
55+
}
56+
}
57+
}
58+
4159
#undef i_type
4260
#undef i_key
4361
#include <stc/priv/template2.h>

pyccel/stdlib/gFTL_functions/Set_extensions.inc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,23 @@ contains
2222

2323
end function __IDENTITY(Set)_pop
2424

25+
subroutine __IDENTITY(Set)_intersection_update(this, other_set)
26+
class(Set), intent(inout) :: this
27+
class(Set), intent(in) :: other_set
28+
29+
type(SetIterator) :: iter
30+
type(SetIterator) :: last
31+
32+
iter = this%begin()
33+
last = this%end()
34+
do while (iter /= last)
35+
if (other_set % count(iter%of()) == 0) then
36+
iter = this%erase(iter)
37+
else
38+
call iter%next()
39+
end if
40+
end do
41+
42+
end subroutine __IDENTITY(Set)_intersection_update
43+
2544
#include <set/tail.inc>

tests/epyccel/test_epyccel_sets.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -438,41 +438,41 @@ def union_int():
438438
assert python_result[0] == pyccel_result[0]
439439
assert set(python_result[1:]) == set(pyccel_result[1:])
440440

441-
def test_set_intersection_int(python_only_language):
441+
def test_set_intersection_int(language):
442442
def intersection_int():
443443
a = {1,2,3}
444444
b = {2,3,4}
445445
c = a.intersection(b)
446446
return len(c), c.pop(), c.pop()
447447

448-
epyccel_func = epyccel(intersection_int, language = python_only_language)
448+
epyccel_func = epyccel(intersection_int, language = language)
449449
pyccel_result = epyccel_func()
450450
python_result = intersection_int()
451451
assert python_result[0] == pyccel_result[0]
452452
assert set(python_result[1:]) == set(pyccel_result[1:])
453453

454-
def test_set_intersection_no_args(python_only_language):
454+
def test_set_intersection_no_args(language):
455455
def intersection_int():
456456
a = {1,2,3,4}
457457
c = a.intersection()
458458
a.add(5)
459459
return len(c), c.pop(), c.pop(), c.pop(), c.pop()
460460

461-
epyccel_func = epyccel(intersection_int, language = python_only_language)
461+
epyccel_func = epyccel(intersection_int, language = language)
462462
pyccel_result = epyccel_func()
463463
python_result = intersection_int()
464464
assert python_result[0] == pyccel_result[0]
465465
assert set(python_result[1:]) == set(pyccel_result[1:])
466466

467-
def test_set_intersection_2_args(python_only_language):
467+
def test_set_intersection_2_args(language):
468468
def intersection_int():
469469
a = {1,2,3,4}
470470
b = {5,6,7,2,1,3}
471471
c = {7,6,10,4,2,3,1}
472472
d = a.intersection(b, c)
473473
return len(d), d.pop(), d.pop(), d.pop()
474474

475-
epyccel_func = epyccel(intersection_int, language = python_only_language)
475+
epyccel_func = epyccel(intersection_int, language = language)
476476
pyccel_result = epyccel_func()
477477
python_result = intersection_int()
478478
assert python_result[0] == pyccel_result[0]
@@ -538,52 +538,53 @@ def union_int():
538538
assert python_result[0] == pyccel_result[0]
539539
assert set(python_result[1:]) == set(pyccel_result[1:])
540540

541-
def test_temporary_set_intersection(python_only_language):
541+
def test_temporary_set_intersection(language):
542542
def intersection_int():
543543
a = {1,2}
544544
b = {2}
545545
d = a.intersection(b).pop()
546546
return d
547547

548-
epyccel_func = epyccel(intersection_int, language = python_only_language)
548+
epyccel_func = epyccel(intersection_int, language = language)
549549
pyccel_result = epyccel_func()
550550
python_result = intersection_int()
551551
assert python_result == pyccel_result
552552

553-
def test_set_intersection_list(python_only_language):
554-
def intersection_list():
555-
a = {1.2, 2.3, 5.0}
556-
b = [1.2, 5.0, 4.0]
557-
d = a.intersection(b)
558-
return len(d), d.pop(), d.pop()
553+
def test_set_intersection_operator(language):
554+
def intersection_int():
555+
a = {1,2,3,4,8}
556+
b = {5,2,3,7,8}
557+
c = a & b
558+
return len(c), c.pop(), c.pop(), c.pop()
559559

560-
epyccel_func = epyccel(intersection_list, language = python_only_language)
560+
epyccel_func = epyccel(intersection_int, language = language)
561561
pyccel_result = epyccel_func()
562-
python_result = intersection_list()
562+
python_result = intersection_int()
563563
assert python_result[0] == pyccel_result[0]
564564
assert set(python_result[1:]) == set(pyccel_result[1:])
565565

566-
def test_set_intersection_tuple(python_only_language):
567-
def intersection_tuple():
568-
a = {True}
569-
b = (False, True)
570-
d = a.intersection(b)
571-
return len(d), d.pop()
566+
def test_set_intersection_update(language):
567+
def intersection_int():
568+
a = {1,2,3,4,8}
569+
b = {5,2,3,7,8}
570+
a.intersection_update(b)
571+
return len(a), a.pop(), a.pop(), a.pop()
572572

573-
epyccel_func = epyccel(intersection_tuple, language = python_only_language)
573+
epyccel_func = epyccel(intersection_int, language = language)
574574
pyccel_result = epyccel_func()
575-
python_result = intersection_tuple()
575+
python_result = intersection_int()
576576
assert python_result[0] == pyccel_result[0]
577577
assert set(python_result[1:]) == set(pyccel_result[1:])
578578

579-
def test_set_intersection_operator(python_only_language):
579+
def test_set_intersection_multiple_update(language):
580580
def intersection_int():
581581
a = {1,2,3,4,8}
582582
b = {5,2,3,7,8}
583-
c = a & b
584-
return len(c), c.pop(), c.pop(), c.pop()
583+
c = {10,2,20}
584+
a.intersection_update(b, c)
585+
return len(a), a.pop()
585586

586-
epyccel_func = epyccel(intersection_int, language = python_only_language)
587+
epyccel_func = epyccel(intersection_int, language = language)
587588
pyccel_result = epyccel_func()
588589
python_result = intersection_int()
589590
assert python_result[0] == pyccel_result[0]
@@ -613,14 +614,14 @@ def union_int():
613614
python_result = union_int()
614615
assert python_result == pyccel_result
615616

616-
def test_set_intersection_augoperator(python_only_language):
617+
def test_set_intersection_augoperator(language):
617618
def intersection_int():
618619
a = {1,2,3,4}
619620
b = {2,3,4}
620621
a &= b
621622
return len(a), a.pop(), a.pop(), a.pop()
622623

623-
epyccel_func = epyccel(intersection_int, language = python_only_language)
624+
epyccel_func = epyccel(intersection_int, language = language)
624625
pyccel_result = epyccel_func()
625626
python_result = intersection_int()
626627
assert python_result[0] == pyccel_result[0]

0 commit comments

Comments
 (0)