Skip to content

Commit 1cde2ad

Browse files
authored
numpy.flip (#155)
1 parent a4c1351 commit 1cde2ad

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

numba_dpcomp/numba_dpcomp/mlir/numpy/funcs.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ def linalg_index_impl(builder, dim):
5252
if isinstance(dim, int):
5353
return builder.linalg_index(dim)
5454

55+
def _fix_axis(axis, num_dims):
56+
if axis < 0:
57+
axis = axis + num_dims
58+
assert axis >= 0 and axis < num_dims
59+
return axis
60+
5561
def _array_reduce(builder, arg, axis, body, get_init_value):
5662
if axis is None:
5763
shape = arg.shape
@@ -68,9 +74,7 @@ def _array_reduce(builder, arg, axis, body, get_init_value):
6874
elif isinstance(axis, int):
6975
shape = arg.shape
7076
num_dims = len(shape)
71-
if axis < 0:
72-
axis += num_dims
73-
assert axis >= 0 and axis < num_dims
77+
axis = _fix_axis(axis, num_dims)
7478
iterators = [('reduction' if i == axis else 'parallel') for i in range(num_dims)]
7579
dims1 = ','.join(['d%s' % i for i in range(num_dims)])
7680
dims2 = ','.join(['d%s' % i for i in range(num_dims) if i != axis])
@@ -88,6 +92,24 @@ def sum_impl(builder, arg, axis=None):
8892
return _array_reduce(builder, arg, axis, lambda a, b: a + b, lambda b, t: 0)
8993

9094

95+
@register_func('numpy.flip', numpy.flip)
96+
def flip_impl(builder, arg, axis=None):
97+
shape = arg.shape
98+
num_dims = len(shape)
99+
if axis is None:
100+
axis = (True,) * num_dims
101+
elif isinstance(axis, int):
102+
axis = _fix_axis(axis, num_dims)
103+
l = [False] * num_dims
104+
l[axis] = True
105+
axis = tuple(l)
106+
else:
107+
return
108+
109+
offsets = [0 if not axis[i] else shape[i] - 1 for i in range(num_dims)]
110+
strides = [1 if not axis[i] else - 1 for i in range(num_dims)]
111+
return builder.subview(arg, offsets, shape, strides)
112+
91113
def _get_numpy_type(builder, dtype):
92114
types = [
93115
(builder.int8, numpy.int8),

numba_dpcomp/numba_dpcomp/mlir/tests/test_numpy.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,33 @@ def test_reduce_axis(py_func, arr):
300300
jit_func = njit(py_func)
301301
assert_equal(py_func(arr), jit_func(arr))
302302

303+
@parametrize_function_variants("py_func", [
304+
'lambda a: np.flip(a)',
305+
])
306+
@pytest.mark.parametrize("arr", [
307+
np.array([1,2,3,4,5,6], dtype=np.int32),
308+
np.array([[1,2,3],[4,5,6]], dtype=np.int32),
309+
np.array([[[1,2,3],[4,5,6]]], dtype=np.int32),
310+
])
311+
def test_flip1(py_func, arr):
312+
jit_func = njit(py_func)
313+
assert_equal(py_func(arr), jit_func(arr))
314+
315+
@parametrize_function_variants("py_func", [
316+
'lambda a: np.flip(a, axis=0)',
317+
'lambda a: np.flip(a, axis=1)',
318+
'lambda a: np.flip(a, axis=-1)',
319+
'lambda a: np.flip(a, axis=-2)',
320+
])
321+
@pytest.mark.parametrize("arr", [
322+
np.array([[[1,2,3],[4,5,6]]], dtype=np.int32),
323+
np.array([[[1,2,3],[4,5,6]]], dtype=np.float32),
324+
])
325+
@pytest.mark.xfail
326+
def test_flip2(py_func, arr):
327+
jit_func = njit(py_func)
328+
assert_equal(py_func(arr), jit_func(arr))
329+
303330
def test_sum_add():
304331
def py_func(a, b):
305332
return np.add(a, b).sum()

0 commit comments

Comments
 (0)