Skip to content

Commit 20aab16

Browse files
Add Python Support for set initialisation with set() (pyccel#1901)
Add support for the `set()` function in Python. This fixes pyccel#1893. In this PR, the `PythonSetFunction` class is introduced to handle calls to the `set()` function, along with the addition of related tests.
1 parent 28f555d commit 20aab16

File tree

4 files changed

+125
-7
lines changed

4 files changed

+125
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ All notable changes to this project will be documented in this file.
2121
- #1844 : Add line numbers and code to errors from built-in function calls.
2222
- \[INTERNALS\] Added `container_rank` property to `ast.datatypes.PyccelType` objects.
2323
- \[DEVELOPER\] Added an improved traceback to the developer-mode errors for errors in function calls.
24+
- #1893 : Add Python support for set initialisation with `set()`.
2425

2526
### Fixed
2627

pyccel/ast/builtins.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
'PythonRange',
5151
'PythonReal',
5252
'PythonSet',
53+
'PythonSetFunction',
5354
'PythonSum',
5455
'PythonTuple',
5556
'PythonTupleFunction',
@@ -774,6 +775,36 @@ def is_homogeneous(self):
774775
"""
775776
return True
776777

778+
779+
class PythonSetFunction(PyccelFunction):
780+
"""
781+
Class representing a call to the `set` function.
782+
783+
Class representing a call to the `set` function. This is
784+
different to the `{,}` syntax as it only takes one argument
785+
and unpacks any variables.
786+
787+
Parameters
788+
----------
789+
arg : TypedAstNode
790+
The argument passed to the function call.
791+
"""
792+
793+
__slots__ = ('_shape', '_class_type')
794+
name = 'set'
795+
def __new__(cls, arg):
796+
if isinstance(arg.class_type, HomogeneousSetType):
797+
return arg
798+
elif isinstance(arg, (PythonList, PythonSet, PythonTuple)):
799+
return PythonSet(*arg)
800+
else:
801+
return super().__new__(cls)
802+
803+
def __init__(self, copied_obj):
804+
self._class_type = copied_obj.class_type
805+
self._shape = copied_obj.shape
806+
super().__init__(copied_obj)
807+
777808
#==============================================================================
778809
class PythonMap(PyccelFunction):
779810
"""
@@ -1234,4 +1265,5 @@ def print_string(self):
12341265
'str' : LiteralString,
12351266
'type' : PythonType,
12361267
'tuple' : PythonTupleFunction,
1268+
'set' : PythonSetFunction
12371269
}

pyccel/parser/semantic.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pyccel.utilities.strings import random_string
2222
from pyccel.ast.basic import PyccelAstNode, TypedAstNode, ScopedAstNode
2323

24-
from pyccel.ast.builtins import PythonPrint, PythonTupleFunction
24+
from pyccel.ast.builtins import PythonPrint, PythonTupleFunction, PythonSetFunction
2525
from pyccel.ast.builtins import PythonComplex
2626
from pyccel.ast.builtins import builtin_functions_dict, PythonImag, PythonReal
2727
from pyccel.ast.builtins import PythonList, PythonConjugate , PythonSet
@@ -64,7 +64,7 @@
6464
from pyccel.ast.datatypes import PrimitiveIntegerType, HomogeneousListType, StringType, SymbolicType
6565
from pyccel.ast.datatypes import PythonNativeBool, PythonNativeInt, PythonNativeFloat
6666
from pyccel.ast.datatypes import DataTypeFactory, PrimitiveFloatingPointType
67-
from pyccel.ast.datatypes import InhomogeneousTupleType, HomogeneousTupleType
67+
from pyccel.ast.datatypes import InhomogeneousTupleType, HomogeneousTupleType, HomogeneousSetType
6868
from pyccel.ast.datatypes import PrimitiveComplexType, FixedSizeNumericType
6969

7070
from pyccel.ast.functionalexpr import FunctionalSum, FunctionalMax, FunctionalMin, GeneratorComprehension, FunctionalFor
@@ -146,6 +146,12 @@
146146
errors = Errors()
147147
pyccel_stage = PyccelStage()
148148

149+
type_container = {
150+
PythonTupleFunction : HomogeneousTupleType,
151+
PythonList : HomogeneousListType,
152+
PythonSetFunction : HomogeneousSetType
153+
}
154+
149155
#==============================================================================
150156

151157
def _get_name(var):
@@ -1920,18 +1926,16 @@ def _get_indexed_type(self, base, args, expr):
19201926
raise errors.report(f"Unknown annotation base {base}\n"+PYCCEL_RESTRICTION_TODO,
19211927
severity='fatal', symbol=expr)
19221928
rank = 1
1923-
if len(args) == 2 and args[1] is LiteralEllipsis():
1929+
if len(args) == 2 and args[1] is LiteralEllipsis() or len(args) == 1:
19241930
syntactic_annotation = args[0]
19251931
if not isinstance(syntactic_annotation, SyntacticTypeAnnotation):
19261932
pyccel_stage.set_stage('syntactic')
19271933
syntactic_annotation = SyntacticTypeAnnotation(dtype=syntactic_annotation)
19281934
pyccel_stage.set_stage('semantic')
19291935
internal_datatypes = self._visit(syntactic_annotation)
19301936
type_annotations = []
1931-
if dtype_cls is PythonTupleFunction:
1932-
class_type = HomogeneousTupleType
1933-
elif dtype_cls is PythonList:
1934-
class_type = HomogeneousListType
1937+
if dtype_cls in type_container :
1938+
class_type = type_container[dtype_cls]
19351939
else:
19361940
raise errors.report(f"Unknown annotation base {base}\n"+PYCCEL_RESTRICTION_TODO,
19371941
severity='fatal', symbol=expr)

tests/epyccel/test_epyccel_sets.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,84 @@ def update_tuple_as_arg():
283283
pyccel_result = epyccel_update()
284284
python_result = update_tuple_as_arg()
285285
assert python_result == pyccel_result
286+
287+
def test_set_with_list(language):
288+
def set_With_list():
289+
a = [1.6, 6.3, 7.2]
290+
b = set(a)
291+
return b
292+
293+
epyc_set_With_list = epyccel(set_With_list, language = language)
294+
pyccel_result = epyc_set_With_list()
295+
python_result = set_With_list()
296+
assert isinstance(python_result, type(pyccel_result))
297+
assert python_result == pyccel_result
298+
299+
def test_set_with_tuple(language):
300+
def set_With_tuple():
301+
a = (1j, 6j, 7j)
302+
b = set(a)
303+
return b
304+
305+
epyc_set_With_tuple = epyccel(set_With_tuple, language = language)
306+
pyccel_result = epyc_set_With_tuple()
307+
python_result = set_With_tuple()
308+
assert isinstance(python_result, type(pyccel_result))
309+
assert python_result == pyccel_result
310+
311+
def test_set_with_set(language):
312+
def set_With_set():
313+
a = {True, False, True} #pylint: disable=duplicate-value
314+
b = set(a)
315+
return b
316+
317+
epyc_set_With_set = epyccel(set_With_set, language = language)
318+
pyccel_result = epyc_set_With_set()
319+
python_result = set_With_set()
320+
assert isinstance(python_result, type(pyccel_result))
321+
assert python_result == pyccel_result
322+
323+
def test_init_with_set(language):
324+
def init_with_set():
325+
b = set({4.6, 7.9, 2.5})
326+
return b
327+
328+
epyc_init_with_set = epyccel(init_with_set, language = language)
329+
pyccel_result = epyc_init_with_set()
330+
python_result = init_with_set()
331+
assert isinstance(python_result, type(pyccel_result))
332+
assert python_result == pyccel_result
333+
334+
def test_set_init_with_list(language):
335+
def init_with_list():
336+
b = set([4.6, 7.9, 2.5])
337+
return b
338+
339+
epyc_init_with_list = epyccel(init_with_list, language = language)
340+
pyccel_result = epyc_init_with_list()
341+
python_result = init_with_list()
342+
assert isinstance(python_result, type(pyccel_result))
343+
assert python_result == pyccel_result
344+
345+
346+
def test_set_copy_from_arg1(language):
347+
def copy_from_arg1(a : 'list[float]'):
348+
b = set(a)
349+
return b
350+
a = [2.5, 1.4, 9.2]
351+
epyc_copy_from_arg = epyccel(copy_from_arg1, language = language)
352+
pyccel_result = epyc_copy_from_arg(a)
353+
python_result = copy_from_arg1(a)
354+
assert isinstance(python_result, type(pyccel_result))
355+
assert python_result == pyccel_result
356+
357+
def test_set_copy_from_arg2(language):
358+
def copy_from_arg2(a : 'set[float]'):
359+
b = set(a)
360+
return b
361+
a = {2.5, 1.4, 9.2}
362+
epyc_copy_from_arg = epyccel(copy_from_arg2, language = language)
363+
pyccel_result = epyc_copy_from_arg(a)
364+
python_result = copy_from_arg2(a)
365+
assert isinstance(python_result, type(pyccel_result))
366+
assert python_result == pyccel_result

0 commit comments

Comments
 (0)