Skip to content

Commit e248f8a

Browse files
authored
Fix array unpacking (pyccel#1793)
Add missing `PythonTuple` cast to fix array unpacking. Fixes pyccel#1792. Add tests including tests returning slices. Fix C bug returning slices. Fixes pyccel#1795.
1 parent 59879b7 commit e248f8a

File tree

5 files changed

+98
-27
lines changed

5 files changed

+98
-27
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ All notable changes to this project will be documented in this file.
1717
- Fix some cases where a Python built-in type is returned in place of a NumPy type.
1818
- Stop printing numbers with more decimal digits than their precision.
1919
- Allow printing the result of a function returning multiple objects of different types.
20+
- #1792 : Fix array unpacking.
21+
- #1795 : Fix bug when returning slices in C.
2022

2123
### Changed
2224
- #1720 : functions with the `@inline` decorator are no longer exposed to Python in the shared library.

pyccel/codegen/printing/ccode.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,8 +2131,7 @@ def _print_AliasAssign(self, expr):
21312131

21322132
# the below condition handles the case of reassinging a pointer to an array view.
21332133
# setting the pointer's is_view attribute to false so it can be ignored by the free_pointer function.
2134-
if not self.is_c_pointer(lhs_var) and \
2135-
isinstance(lhs_var, Variable) and lhs_var.is_ndarray:
2134+
if isinstance(lhs_var, Variable) and lhs_var.is_ndarray and not lhs_var.is_optional:
21362135
rhs = self._print(rhs_var)
21372136

21382137
if isinstance(rhs_var, Variable) and rhs_var.is_ndarray:
@@ -2143,12 +2142,12 @@ def _print_AliasAssign(self, expr):
21432142
return 'transpose_alias_assign({}, {});\n'.format(lhs, rhs)
21442143
else:
21452144
lhs = self._print(lhs_var)
2146-
return '{} = {};\n'.format(lhs, rhs)
2145+
return f'{lhs} = {rhs};\n'
21472146
else:
21482147
lhs = self._print(lhs_address)
21492148
rhs = self._print(rhs_address)
21502149

2151-
return '{} = {};\n'.format(lhs, rhs)
2150+
return f'{lhs} = {rhs};\n'
21522151

21532152
def _print_For(self, expr):
21542153
self.set_scope(expr.scope)

pyccel/parser/semantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
from pyccel.ast.headers import FunctionHeader, MethodHeader, Header
7575
from pyccel.ast.headers import MacroFunction, MacroVariable
7676

77-
from pyccel.ast.internals import PyccelInternalFunction, Slice, PyccelSymbol
77+
from pyccel.ast.internals import PyccelInternalFunction, Slice, PyccelSymbol, PyccelArrayShapeElement
7878
from pyccel.ast.itertoolsext import Product
7979

8080
from pyccel.ast.literals import LiteralTrue, LiteralFalse
@@ -3363,7 +3363,7 @@ def _visit_Assign(self, expr):
33633363
new_lhs.append( self._assign_lhs_variable(l, d_var[i].copy(), rhs, new_expressions, isinstance(expr, AugAssign)) )
33643364
lhs = PythonTuple(*new_lhs)
33653365

3366-
elif d_var['shape'][0]==n:
3366+
elif d_var['shape'][0]==n or isinstance(d_var['shape'][0], PyccelArrayShapeElement):
33673367
new_lhs = []
33683368
new_rhs = []
33693369

@@ -3372,7 +3372,7 @@ def _visit_Assign(self, expr):
33723372
new_rhs.append(r)
33733373

33743374
lhs = PythonTuple(*new_lhs)
3375-
rhs = new_rhs
3375+
rhs = PythonTuple(*new_rhs)
33763376
else:
33773377
errors.report(WRONG_NUMBER_OUTPUT_ARGS, symbol=expr, severity='error')
33783378
return None

tests/epyccel/modules/arrays.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,14 +1718,6 @@ def arr_arange_7(arr : 'int[:,:]'):
17181718
for i in range(n):
17191719
arr[i] = np.arange(i, i+m)
17201720

1721-
def iterate_slice(i : int):
1722-
import numpy as np
1723-
a = np.arange(15)
1724-
res = 0
1725-
for ai in a[:i]:
1726-
res += ai
1727-
return res
1728-
17291721
#==============================================================================
17301722
# NUMPY SUM
17311723
#==============================================================================
@@ -1753,3 +1745,32 @@ def multiple_np_linspace():
17531745
y = np.linspace(0, 4, 128)
17541746
z = np.linspace(0, 8, 128)
17551747
return x[0] + y[1] + z[2] + linspace_index
1748+
1749+
#==============================================================================
1750+
# Iteration
1751+
#==============================================================================
1752+
1753+
def iterate_slice(i : int):
1754+
from numpy import arange
1755+
a = arange(15)
1756+
res = 0
1757+
for ai in a[:i]:
1758+
res += ai
1759+
return res
1760+
1761+
@template('T', ['int[:]', 'int[:,:]', 'int[:,:,:]', 'int[:,:](order=F)', 'int[:,:,:](order=F)'])
1762+
def unpack_array(arr : 'T'):
1763+
x, y, z = arr[:]
1764+
return x, y, z
1765+
1766+
def unpack_array_of_known_size():
1767+
from numpy import array
1768+
arr = array([1,2,3], dtype='float64')
1769+
x, y, z = arr[:]
1770+
return x, y, z
1771+
1772+
def unpack_array_2D_of_known_size():
1773+
from numpy import array
1774+
arr = array([[1,2,3], [4,5,6], [7,8,9]], dtype='float64')
1775+
x, y, z = arr[:]
1776+
return x.sum(), y.sum(), z.sum()

tests/epyccel/test_arrays.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3647,11 +3647,6 @@ def test_numpy_arange_into_slice(language):
36473647
f2(x)
36483648
np.testing.assert_allclose(x, x_expected, rtol=RTOL, atol=ATOL)
36493649

3650-
def test_iterate_slice(language):
3651-
f1 = arrays.iterate_slice
3652-
f2 = epyccel(f1, language = language)
3653-
i = randint(2, 10)
3654-
assert f1(i) == f2(i)
36553650
##==============================================================================
36563651
## TEST NESTED ARRAYS INITIALIZATION WITH ORDER C
36573652
##==============================================================================
@@ -3975,11 +3970,65 @@ def test_array_ndmin_2_order(language):
39753970
check_array_equal(f1(d), f2(d))
39763971
check_array_equal(f1(e), f2(e))
39773972

3973+
##==============================================================================
3974+
## TEST ITERATION
3975+
##==============================================================================
3976+
3977+
def test_iterate_slice(language):
3978+
f1 = arrays.iterate_slice
3979+
f2 = epyccel(f1, language = language)
3980+
i = randint(2, 10)
3981+
assert f1(i) == f2(i)
3982+
3983+
@pytest.mark.parametrize( 'language', (
3984+
pytest.param("fortran", marks = [
3985+
pytest.mark.xfail(reason=("Cannot return a non-contiguous slice. See #1796")),
3986+
pytest.mark.fortran]),
3987+
pytest.param("c", marks = pytest.mark.c),
3988+
pytest.param("python", marks = pytest.mark.python)
3989+
)
3990+
)
3991+
def test_unpacking(language):
3992+
f1 = arrays.unpack_array
3993+
f2 = epyccel(f1, language = language)
3994+
3995+
arr = np.arange(3, dtype=int)
3996+
assert f1(arr) == f2(arr)
3997+
3998+
arr = np.arange(12, dtype=int).reshape((3,4))
3999+
x1, y1, z1 = f1(arr)
4000+
x2, y2, z2 = f2(arr)
4001+
check_array_equal(x1, x2)
4002+
check_array_equal(y1, y2)
4003+
check_array_equal(z1, z2)
39784004

3979-
#def teardown_module():
3980-
# import os, glob
3981-
# dirname = os.path.dirname( arrays.__file__ )
3982-
# pattern = os.path.join( dirname, '__epyccel__*' )
3983-
# filelist = glob.glob( pattern )
3984-
# for f in filelist:
3985-
# os.remove( f )
4005+
arr = np.arange(24, dtype=int).reshape((3,4,2))
4006+
x1, y1, z1 = f1(arr)
4007+
x2, y2, z2 = f2(arr)
4008+
check_array_equal(x1, x2)
4009+
check_array_equal(y1, y2)
4010+
check_array_equal(z1, z2)
4011+
4012+
arr = np.arange(12, dtype=int).reshape((3,4), order='F')
4013+
x1, y1, z1 = f1(arr)
4014+
x2, y2, z2 = f2(arr)
4015+
check_array_equal(x1, x2)
4016+
check_array_equal(y1, y2)
4017+
check_array_equal(z1, z2)
4018+
4019+
arr = np.arange(24, dtype=int).reshape((3,4,2), order='F')
4020+
x1, y1, z1 = f1(arr)
4021+
x2, y2, z2 = f2(arr)
4022+
check_array_equal(x1, x2)
4023+
check_array_equal(y1, y2)
4024+
check_array_equal(z1, z2)
4025+
4026+
def test_unpacking_of_known_size(language):
4027+
f1 = arrays.unpack_array_of_known_size
4028+
f2 = epyccel(f1, language = language)
4029+
assert f1() == f2()
4030+
4031+
def test_unpacking_2D_of_known_size(language):
4032+
f1 = arrays.unpack_array_2D_of_known_size
4033+
f2 = epyccel(f1, language = language)
4034+
assert f1() == f2()

0 commit comments

Comments
 (0)