Skip to content

Commit 5152d2b

Browse files
EmilyBourneyguclu
andauthored
Fix array to slice assignment (pyccel#1810)
Fix assigning an array to a slice and add tests. Fixes pyccel#1218 This is achieved by fixing the indexing in the case where the variable on the left-hand side of the assignment has a higher rank than the variable on the right hand side. --------- Co-authored-by: Yaman Güçlü <[email protected]>
1 parent f17d55b commit 5152d2b

File tree

4 files changed

+112
-20
lines changed

4 files changed

+112
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ All notable changes to this project will be documented in this file.
2525
- #1795 : Fix bug when returning slices in C.
2626
- #1732 : Fix multidimensional list indexing in Python.
2727
- #1785 : Add missing cast when creating an array of booleans from non-boolean values.
28+
- #1218 : Fix bug when assigning an array to a slice in Fortran.
2829

2930
### Changed
3031
- #1720 : functions with the `@inline` decorator are no longer exposed to Python in the shared library.

pyccel/ast/utilities.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -494,14 +494,14 @@ def collect_loops(block, indices, new_index, language_has_vectors = False, resul
494494
# Loop over indexes, inserting until the expression can be evaluated
495495
# in the desired language
496496
new_level = 0
497-
for index in range(-rank,0):
497+
for index_depth in range(-rank, 0):
498498
new_level += 1
499499
# If an index exists at the same depth, reuse it if not create one
500-
if rank+index >= len(indices):
501-
indices.append(new_index(PythonNativeInt(),'i'))
502-
index_var = indices[rank+index]
503-
new_vars = [insert_index(v, index, index_var) for v in new_vars]
504-
handled_funcs = [insert_index(v, index, index_var) for v in handled_funcs]
500+
if rank+index_depth >= len(indices):
501+
indices.append(new_index(PythonNativeInt(), 'i'))
502+
index = indices[rank+index_depth]
503+
new_vars = [insert_index(v, index_depth, index) for v in new_vars]
504+
handled_funcs = [insert_index(v, index_depth, index) for v in handled_funcs]
505505
if compatible_operation(*new_vars, *handled_funcs, language_has_vectors = language_has_vectors):
506506
break
507507

@@ -541,22 +541,34 @@ def collect_loops(block, indices, new_index, language_has_vectors = False, resul
541541
current_level = new_level
542542

543543
elif isinstance(line, Assign) and isinstance(line.lhs, IndexedElement) \
544-
and isinstance(line.rhs, (PythonTuple, NumpyArray)) and not language_has_vectors:
545-
544+
and isinstance(line.rhs, (PythonTuple, NumpyArray)):
546545
lhs = line.lhs
547546
rhs = line.rhs
548-
if isinstance(rhs, NumpyArray):
549-
rhs = rhs.arg
550-
551-
lhs_rank = lhs.rank
552-
553-
new_assigns = [Assign(
554-
insert_index(expr=lhs,
555-
pos = -lhs_rank,
556-
index_var = LiteralInteger(j)),
557-
rj) # lhs[j] = rhs[j]
558-
for j, rj in enumerate(rhs)]
559-
collect_loops(new_assigns, indices, new_index, language_has_vectors, result = result)
547+
if lhs.rank > rhs.rank:
548+
for index_depth in range(lhs.rank-rhs.rank):
549+
# If an index exists at the same depth, reuse it if not create one
550+
if index_depth >= len(indices):
551+
indices.append(new_index(PythonNativeInt(), 'i'))
552+
index = indices[index_depth]
553+
lhs = insert_index(lhs, index_depth, index)
554+
collect_loops([Assign(lhs, rhs)], indices, new_index, language_has_vectors, result = result)
555+
556+
elif not language_has_vectors:
557+
if isinstance(rhs, NumpyArray):
558+
rhs = rhs.arg
559+
560+
lhs_rank = lhs.rank
561+
562+
new_assigns = [Assign(
563+
insert_index(expr=lhs,
564+
pos = -lhs_rank,
565+
index_var = LiteralInteger(j)),
566+
rj) # lhs[j] = rhs[j]
567+
for j, rj in enumerate(rhs)]
568+
collect_loops(new_assigns, indices, new_index, language_has_vectors, result = result)
569+
570+
else:
571+
result.append(line)
560572

561573
elif isinstance(line, Assign) and isinstance(line.rhs, Concatenate):
562574
lhs = line.lhs

tests/epyccel/modules/arrays.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,29 @@ def array_2d_C_slice_stride_23(a : 'int[:,:]'):
15811581
b = a[::d, ::c]
15821582
return np.sum(b), b[0][0], b[-1][-1], len(b), len(b[0])
15831583

1584+
#==============================================================================
1585+
# Slice assignment
1586+
#==============================================================================
1587+
1588+
def copy_to_slice_issue_1218(n : int):
1589+
from numpy import zeros, array
1590+
x = 1
1591+
arr = zeros((2, n))
1592+
arr[0:x, 0:6:2] = array([2, 5, 6])
1593+
return arr
1594+
1595+
def copy_to_slice_1(a : 'float[:]', b : 'float[:]'):
1596+
a[1:-1] = b
1597+
1598+
def copy_to_slice_2(a : 'float[:,:]', b : 'float[:]'):
1599+
a[:, 1:-1] = b
1600+
1601+
def copy_to_slice_3(a : 'float[:,:]', b : 'float[:]'):
1602+
a[:, 0] = b
1603+
1604+
def copy_to_slice_4(a : 'float[:]', b : 'float[:]'):
1605+
a[::2] = b
1606+
15841607
#==============================================================================
15851608
# ARITHMETIC OPERATIONS
15861609
#==============================================================================

tests/epyccel/test_arrays.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3488,6 +3488,62 @@ def test_array_2d_C_slice_stride_23(language):
34883488
f2 = epyccel(f1, language = language)
34893489
assert f1(a) == f2(a)
34903490

3491+
#==============================================================================
3492+
# TEST : Slice assignment
3493+
#==============================================================================
3494+
def test_copy_to_slice_issue_1218(language):
3495+
pyth_f = arrays.copy_to_slice_issue_1218
3496+
epyc_f = epyccel(pyth_f, language = language)
3497+
3498+
n = 10
3499+
pyth_arr = pyth_f(n)
3500+
epyc_arr = epyc_f(n)
3501+
check_array_equal(pyth_arr, epyc_arr)
3502+
3503+
def test_copy_to_slice_1(language):
3504+
pyth_f = arrays.copy_to_slice_1
3505+
epyc_f = epyccel(pyth_f, language = language)
3506+
3507+
pyth_a = np.arange(10, dtype=float)
3508+
epyc_a = pyth_a.copy()
3509+
b = np.arange(20, 28, dtype=float)
3510+
pyth_f(pyth_a, b)
3511+
epyc_f(epyc_a, b)
3512+
check_array_equal(pyth_a, epyc_a)
3513+
3514+
def test_copy_to_slice_2(language):
3515+
pyth_f = arrays.copy_to_slice_2
3516+
epyc_f = epyccel(pyth_f, language = language)
3517+
3518+
pyth_a = np.arange(20, dtype=float).reshape(2, 10)
3519+
epyc_a = pyth_a.copy()
3520+
b = np.arange(20, 28, dtype=float)
3521+
pyth_f(pyth_a, b)
3522+
epyc_f(epyc_a, b)
3523+
check_array_equal(pyth_a, epyc_a)
3524+
3525+
def test_copy_to_slice_3(language):
3526+
pyth_f = arrays.copy_to_slice_3
3527+
epyc_f = epyccel(pyth_f, language = language)
3528+
3529+
pyth_a = np.arange(20, dtype=float).reshape(4, 5)
3530+
epyc_a = pyth_a.copy()
3531+
b = np.arange(20, 24, dtype=float)
3532+
pyth_f(pyth_a, b)
3533+
epyc_f(epyc_a, b)
3534+
check_array_equal(pyth_a, epyc_a)
3535+
3536+
def test_copy_to_slice_4(language):
3537+
pyth_f = arrays.copy_to_slice_4
3538+
epyc_f = epyccel(pyth_f, language = language)
3539+
3540+
pyth_a = np.arange(10, dtype=float)
3541+
epyc_a = pyth_a.copy()
3542+
b = np.arange(20, 25, dtype=float)
3543+
pyth_f(pyth_a, b)
3544+
epyc_f(epyc_a, b)
3545+
check_array_equal(pyth_a, epyc_a)
3546+
34913547
#==============================================================================
34923548
# TEST : arithmetic operations
34933549
#==============================================================================

0 commit comments

Comments
 (0)