Skip to content

Commit 3987c17

Browse files
authored
Better reshape handling (#185)
1 parent 23e9b88 commit 3987c17

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

numba_dpcomp/numba_dpcomp/mlir/numpy/funcs.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,16 @@ def dtype_impl(builder, arg):
583583

584584
@register_func("array.reshape")
585585
@register_func("numpy.reshape", numpy.reshape)
586-
def reshape_impl(builder, arg, new_shape):
586+
def reshape_impl(builder, arg, *new_shape):
587+
if len(new_shape) == 1:
588+
new_shape = new_shape[0]
589+
590+
# TODO: better check
591+
if isinstance(new_shape, tuple):
592+
for s in new_shape:
593+
if isinstance(s, int) and s < 0:
594+
return # not supported for now
595+
587596
return builder.reshape(arg, new_shape)
588597

589598

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def _gen_tests():
116116
"test_tuple_concat", # tuple concat
117117
"test_two_d_array_reduction_with_float_sizes", # np.array
118118
"test_two_d_array_reduction_reuse", # np.arange
119-
"test_parfor_slice21", # unsupported reshape
120119
"test_parfor_array_access_lower_slice", # np.arange
121120
"test_size_assertion", # AssertionError not raised
122121
"test_parfor_slice18", # np.arange

numba_dpcomp/numba_dpcomp/mlir/tests/test_numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,9 @@ def py_func(a):
10041004
"lambda a: a.reshape((a.size,))",
10051005
"lambda a: a.reshape((a.size,1))",
10061006
"lambda a: a.reshape((1, a.size))",
1007+
"lambda a: a.reshape(1, a.size)",
10071008
"lambda a: a.reshape((1, a.size, 1))",
1009+
"lambda a: a.reshape(1, a.size, 1)",
10081010
"lambda a: np.reshape(a, a.size)",
10091011
"lambda a: np.reshape(a, (a.size,))",
10101012
"lambda a: np.reshape(a, (a.size,1))",

0 commit comments

Comments
 (0)