Skip to content

Commit 386a150

Browse files
authored
Fix pointer assignment for list/set/dict (pyccel#2038)
Fix the recognition of an alias assign for an expression such as `a = b` where `a` and `b` are lists/sets/dicts. Fixes pyccel#2008. Fix the printing of a list/set/dict pointer
1 parent 9f7fb7c commit 386a150

File tree

8 files changed

+48
-11
lines changed

8 files changed

+48
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ All notable changes to this project will be documented in this file.
7171
- #1979 : Fix memory leaks in C due to homogeneous container redefinition.
7272
- #1972 : Simplified `printf` statement for Literal String.
7373
- #2026 : Fix missing loop in slice assignment.
74+
- #2008 : Ensure list/set/dict assignment is recognised as a reference.
7475

7576
### Changed
7677

pyccel/ast/builtins.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -732,10 +732,8 @@ class PythonListFunction(PyccelFunction):
732732
def __new__(cls, arg = None):
733733
if arg is None:
734734
return PythonList()
735-
elif isinstance(arg, PythonList):
736-
return arg
737735
elif isinstance(arg.shape[0], LiteralInteger):
738-
return PythonList(*[arg[i] for i in range(arg.shape[0])])
736+
return PythonList(*arg)
739737
else:
740738
return super().__new__(cls)
741739

@@ -838,8 +836,6 @@ class PythonSetFunction(PyccelFunction):
838836
def __new__(cls, arg = None):
839837
if arg is None:
840838
return PythonSet()
841-
elif isinstance(arg.class_type, HomogeneousSetType):
842-
return arg
843839
elif isinstance(arg, (PythonList, PythonSet, PythonTuple)):
844840
return PythonSet(*arg)
845841
else:

pyccel/codegen/printing/ccode.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def is_c_pointer(self, a):
359359

360360
if not isinstance(a, Variable):
361361
return False
362-
return (a.is_alias and not isinstance(a.class_type, HomogeneousContainerType)) \
362+
return (a.is_alias and not isinstance(a.class_type, (HomogeneousTupleType, NumpyNDArrayType))) \
363363
or a.is_optional or \
364364
any(a is bi for b in self._additional_args for bi in b)
365365

@@ -1363,10 +1363,9 @@ def get_declare_type(self, expr):
13631363
if rank > 0:
13641364
if isinstance(expr.class_type, (HomogeneousSetType, HomogeneousListType, DictType)):
13651365
dtype = self.get_c_type(expr.class_type)
1366-
return dtype
1367-
if isinstance(expr.class_type, CStackArray):
1366+
elif isinstance(expr.class_type, CStackArray):
13681367
return self.get_c_type(expr.class_type.element_type)
1369-
if isinstance(expr.class_type,(HomogeneousTupleType, NumpyNDArrayType)):
1368+
elif isinstance(expr.class_type, (HomogeneousTupleType, NumpyNDArrayType)):
13701369
if expr.rank > 15:
13711370
errors.report(UNSUPPORTED_ARRAY_RANK, symbol=expr, severity='fatal')
13721371
self.add_import(c_imports['ndarrays'])
@@ -1741,6 +1740,8 @@ def _print_Allocate(self, expr):
17411740

17421741
def _print_Deallocate(self, expr):
17431742
if isinstance(expr.variable.class_type, (HomogeneousListType, HomogeneousSetType, DictType)):
1743+
if expr.variable.is_alias:
1744+
return ''
17441745
variable_address = self._print(ObjectAddress(expr.variable))
17451746
container_type = self.get_c_type(expr.variable.class_type)
17461747
return f'{container_type}_drop({variable_address});\n'

pyccel/codegen/printing/fcode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1981,7 +1981,7 @@ def _print_AliasAssign(self, expr):
19811981
# TODO improve
19821982
op = '=>'
19831983
shape_code = ''
1984-
if lhs.rank > 0:
1984+
if isinstance(lhs.class_type, (NumpyNDArrayType, HomogeneousTupleType)):
19851985
shape_code = ', '.join('0:' for i in range(lhs.rank))
19861986
shape_code = '({s_c})'.format(s_c = shape_code)
19871987

pyccel/parser/semantic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1401,7 +1401,8 @@ def _ensure_target(self, rhs, d_lhs):
14011401
d_lhs['memory_handling'] = 'alias'
14021402
rhs.internal_var.is_target = True
14031403

1404-
if isinstance(rhs, Variable) and (rhs.is_ndarray or isinstance(rhs.class_type, CustomDataType)):
1404+
if isinstance(rhs, Variable) and (rhs.rank > 0 or isinstance(rhs.class_type, CustomDataType)) \
1405+
and not isinstance(rhs.class_type, (TupleType, StringType)):
14051406
d_lhs['memory_handling'] = 'alias'
14061407
rhs.is_target = not rhs.is_alias
14071408

tests/epyccel/test_epyccel_dicts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,16 @@ def dict_contains():
191191
python_result = dict_contains()
192192
assert isinstance(python_result, type(pyccel_result))
193193
assert python_result == pyccel_result
194+
195+
def test_dict_ptr(python_only_language):
196+
def dict_ptr():
197+
a = {1:1.0, 2:2.0, 3:3.0}
198+
b = a
199+
c = b.pop(2)
200+
return len(a), len(b), c
201+
202+
epyc_func = epyccel(dict_ptr, language = python_only_language)
203+
pyccel_result = epyc_func()
204+
python_result = dict_ptr()
205+
assert isinstance(python_result, type(pyccel_result))
206+
assert python_result == pyccel_result

tests/epyccel/test_epyccel_lists.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,3 +756,16 @@ def list_contains():
756756
python_result = list_contains()
757757
assert isinstance(python_result, type(pyccel_result))
758758
assert python_result == pyccel_result
759+
760+
def test_dict_ptr(language):
761+
def list_ptr():
762+
a = [1, 3, 4, 7, 10, 3]
763+
b = a
764+
b.append(22)
765+
return len(a), len(b)
766+
767+
epyc_list_ptr = epyccel(list_ptr, language = language)
768+
pyccel_result = epyc_list_ptr()
769+
python_result = list_ptr()
770+
assert isinstance(python_result, type(pyccel_result))
771+
assert python_result == pyccel_result

tests/epyccel/test_epyccel_sets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,15 @@ def union_int():
520520
pyccel_result = epyccel_func()
521521
python_result = union_int()
522522
assert python_result == pyccel_result
523+
524+
def test_set_ptr(language):
525+
def set_ptr():
526+
a = {1,2,3,4,5,6,7,8}
527+
b = a
528+
b.pop()
529+
return len(a), len(b)
530+
531+
epyccel_func = epyccel(set_ptr, language = language)
532+
pyccel_result = epyccel_func()
533+
python_result = set_ptr()
534+
assert python_result == pyccel_result

0 commit comments

Comments
 (0)