Skip to content

Commit 6abf2f4

Browse files
authored
Remove temporary workaround from dpnp.put_along_axis (#1839)
* Get rid of call_origin in dpnp.put * Removed temporary w/a in dpnp.put_along_axis
1 parent 614af33 commit 6abf2f4

File tree

3 files changed

+173
-10
lines changed

3 files changed

+173
-10
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ env:
4949
test_usm_type.py
5050
third_party/cupy/core_tests
5151
third_party/cupy/indexing_tests/test_indexing.py
52+
third_party/cupy/lib_tests
5253
third_party/cupy/linalg_tests
5354
third_party/cupy/logic_tests
5455
third_party/cupy/manipulation_tests

dpnp/dpnp_iface_indexing.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -831,8 +831,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
831831
in_a[:] = a.reshape(in_a.shape, copy=False)
832832

833833

834-
# pylint: disable=redefined-outer-name
835-
def put_along_axis(a, indices, values, axis):
834+
def put_along_axis(a, ind, values, axis):
836835
"""
837836
Put values into the destination array by matching 1d index and data slices.
838837
@@ -842,13 +841,13 @@ def put_along_axis(a, indices, values, axis):
842841
----------
843842
a : {dpnp.ndarray, usm_ndarray}, (Ni..., M, Nk...)
844843
Destination array.
845-
indices : {dpnp.ndarray, usm_ndarray}, (Ni..., J, Nk...)
844+
ind : {dpnp.ndarray, usm_ndarray}, (Ni..., J, Nk...)
846845
Indices to change along each 1d slice of `a`. This must match the
847846
dimension of input array, but dimensions in ``Ni`` and ``Nj``
848847
may be 1 to broadcast against `a`.
849848
values : {scalar, array_like}, (Ni..., J, Nk...)
850849
Values to insert at those indices. Its shape and dimension are
851-
broadcast to match that of `indices`.
850+
broadcast to match that of `ind`.
852851
axis : int
853852
The axis to take 1d slices along. If axis is ``None``, the destination
854853
array is treated as if a flattened 1d view had been created of it.
@@ -880,16 +879,12 @@ def put_along_axis(a, indices, values, axis):
880879
881880
"""
882881

883-
dpnp.check_supported_arrays_type(a, indices)
884-
885-
# TODO: remove when #1382(dpctl) is resolved
886-
if dpnp.is_supported_array_type(values) and a.dtype != values.dtype:
887-
values = values.astype(a.dtype)
882+
dpnp.check_supported_arrays_type(a, ind)
888883

889884
if axis is None:
890885
a = a.ravel()
891886

892-
a[_build_along_axis_index(a, indices, axis)] = values
887+
a[_build_along_axis_index(a, ind, axis)] = values
893888

894889

895890
def putmask(x1, mask, values):
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import unittest
2+
3+
import numpy
4+
import pytest
5+
6+
import dpnp as cupy
7+
from tests.third_party.cupy import testing
8+
9+
10+
@testing.parameterize(*(testing.product({"axis": [0, 1, -1]})))
11+
@pytest.mark.skip("'apply_along_axis' is not implemented yet")
12+
class TestApplyAlongAxis(unittest.TestCase):
13+
@testing.numpy_cupy_array_equal()
14+
def test_simple(self, xp):
15+
a = xp.ones((20, 10), "d")
16+
return xp.apply_along_axis(len, self.axis, a)
17+
18+
@testing.for_all_dtypes(no_bool=True)
19+
@testing.numpy_cupy_array_equal()
20+
def test_3d(self, xp, dtype):
21+
a = xp.arange(27, dtype=dtype).reshape((3, 3, 3))
22+
return xp.apply_along_axis(xp.sum, self.axis, a)
23+
24+
@testing.numpy_cupy_array_equal()
25+
def test_0d_array(self, xp):
26+
27+
def sum_to_0d(x):
28+
"""Sum x, returning a 0d array of the same class"""
29+
assert x.ndim == 1
30+
return xp.squeeze(xp.sum(x, keepdims=True))
31+
32+
a = xp.ones((6, 3))
33+
return xp.apply_along_axis(sum_to_0d, self.axis, a)
34+
35+
@testing.numpy_cupy_array_equal()
36+
def test_axis_insertion_2d(self, xp):
37+
38+
def f1to2(x):
39+
"""produces an asymmetric non-square matrix from x"""
40+
assert x.ndim == 1
41+
return x[::-1] * x[1:, None]
42+
43+
# 2d insertion
44+
a2d = xp.arange(6 * 3).reshape((6, 3))
45+
return xp.apply_along_axis(f1to2, self.axis, a2d)
46+
47+
@testing.numpy_cupy_array_equal()
48+
def test_axis_insertion_3d(self, xp):
49+
50+
def f1to2(x):
51+
"""produces an asymmetric non-square matrix from x"""
52+
assert x.ndim == 1
53+
return x[::-1] * x[1:, None]
54+
55+
# 3d insertion
56+
a3d = xp.arange(6 * 5 * 3).reshape((6, 5, 3))
57+
return xp.apply_along_axis(f1to2, self.axis, a3d)
58+
59+
def test_empty1(self):
60+
# can't apply_along_axis when there's no chance to call the function
61+
def never_call(x):
62+
assert False # should never be reached
63+
64+
for xp in [numpy, cupy]:
65+
a = xp.empty((0, 0))
66+
with pytest.raises(ValueError):
67+
xp.apply_along_axis(never_call, self.axis, a)
68+
69+
def test_empty2(self):
70+
# but it's sometimes ok with some non-zero dimensions
71+
def empty_to_1(x):
72+
assert len(x) == 0
73+
return 1
74+
75+
for xp in [numpy, cupy]:
76+
shape = [10, 10]
77+
shape[self.axis] = 0
78+
shape = tuple(shape)
79+
a = xp.empty(shape)
80+
if self.axis == 0:
81+
other_axis = 1
82+
else:
83+
other_axis = 0
84+
with pytest.raises(ValueError):
85+
xp.apply_along_axis(empty_to_1, other_axis, a)
86+
87+
# okay to call along the shape 0 axis
88+
testing.assert_array_equal(
89+
xp.apply_along_axis(empty_to_1, self.axis, a), xp.ones((10,))
90+
)
91+
92+
@testing.numpy_cupy_array_equal()
93+
def test_tuple_outs(self, xp):
94+
def func(x):
95+
return x.sum(axis=-1), x.prod(axis=-1), x.max(axis=-1)
96+
97+
a = testing.shaped_arange((2, 2, 2), xp, cupy.int64)
98+
return xp.apply_along_axis(func, 1, a)
99+
100+
101+
@testing.with_requires("numpy>=1.16")
102+
@pytest.mark.skip("'apply_along_axis' is not implemented yet")
103+
def test_apply_along_axis_invalid_axis():
104+
for xp in [numpy, cupy]:
105+
a = xp.ones((8, 4))
106+
for axis in [-3, 2]:
107+
with pytest.raises(numpy.AxisError):
108+
xp.apply_along_axis(xp.sum, axis, a)
109+
110+
111+
class TestPutAlongAxis(unittest.TestCase):
112+
@testing.for_all_dtypes()
113+
@testing.numpy_cupy_array_equal()
114+
def test_put_along_axis_empty(self, xp, dtype):
115+
a = xp.array([], dtype=dtype).reshape(0, 10)
116+
i = xp.array([], dtype=xp.int64).reshape(0, 10)
117+
vals = xp.array([]).reshape(0, 10)
118+
ret = xp.put_along_axis(a, i, vals, axis=0)
119+
assert ret is None
120+
return a
121+
122+
@testing.for_all_dtypes()
123+
@testing.numpy_cupy_array_equal()
124+
def test_simple(self, xp, dtype):
125+
a = testing.shaped_arange((3, 3, 3), xp, dtype)
126+
indices_max = xp.argmax(a, axis=0, keepdims=True)
127+
ret = xp.put_along_axis(a, indices_max, 0, axis=0)
128+
assert ret is None
129+
return a
130+
131+
@testing.for_all_dtypes()
132+
def test_indices_values_arr_diff_dims(self, dtype):
133+
for xp in [numpy, cupy]:
134+
a = testing.shaped_arange((3, 3, 3), xp, dtype)
135+
i_max = xp.argmax(a, axis=0, keepdims=False)
136+
with pytest.raises(ValueError):
137+
xp.put_along_axis(a, i_max, -99, axis=1)
138+
139+
140+
@testing.parameterize(
141+
*testing.product(
142+
{
143+
"axis": [0, 1],
144+
}
145+
)
146+
)
147+
class TestPutAlongAxes(unittest.TestCase):
148+
def test_replace_max(self):
149+
arr = cupy.array([[10, 30, 20], [60, 40, 50]])
150+
indices_max = cupy.argmax(arr, axis=self.axis, keepdims=True)
151+
# replace the max with a small value
152+
cupy.put_along_axis(arr, indices_max, -99, axis=self.axis)
153+
# find the new minimum, which should max
154+
indices_min = cupy.argmin(arr, axis=self.axis, keepdims=True)
155+
testing.assert_array_equal(indices_min, indices_max)
156+
157+
158+
class TestPutAlongAxisNone(unittest.TestCase):
159+
@testing.for_all_dtypes()
160+
@testing.numpy_cupy_array_equal()
161+
def test_axis_none(self, xp, dtype):
162+
a = testing.shaped_arange((3, 3), xp, dtype)
163+
i = xp.array([1, 3])
164+
val = xp.array([99, 100])
165+
ret = xp.put_along_axis(a, i, val, axis=None)
166+
assert ret is None
167+
return a

0 commit comments

Comments
 (0)