Skip to content

Commit d0eb7cf

Browse files
committed
Take and put tweaks
- take_multi_index and put_multi_index logic for 0D arrays removed, adjusted a test accordingly - take, put, take_multi_index, and put_multi_index axis type check and normalization only reassigns axis once
1 parent 1387634 commit d0eb7cf

File tree

3 files changed

+12
-20
lines changed

3 files changed

+12
-20
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -454,16 +454,12 @@ def _take_multi_index(ary, inds, p):
454454
raise IndexError(
455455
"arrays used as indices must be of integer (or boolean) type"
456456
)
457-
if (len(inds) > 1):
457+
if len(inds) > 1:
458458
inds = dpt.broadcast_arrays(*inds)
459459
ary_ndim = ary.ndim
460-
if ary_ndim > 0:
461-
p = operator.index(p)
462-
p = normalize_axis_index(p, ary_ndim)
460+
p = normalize_axis_index(operator.index(p), ary_ndim)
463461

464-
res_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
465-
else:
466-
res_shape = inds[0].shape
462+
res_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
467463
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
468464
res = dpt.empty(
469465
res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
@@ -541,15 +537,12 @@ def _put_multi_index(ary, inds, p, vals):
541537
raise IndexError(
542538
"arrays used as indices must be of integer (or boolean) type"
543539
)
544-
if (len(inds) > 1):
540+
if len(inds) > 1:
545541
inds = dpt.broadcast_arrays(*inds)
546542
ary_ndim = ary.ndim
547-
if ary_ndim > 0:
548-
p = operator.index(p)
549-
p = normalize_axis_index(p, ary_ndim)
550-
vals_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
551-
else:
552-
vals_shape = inds[0].shape
543+
544+
p = normalize_axis_index(operator.index(p), ary_ndim)
545+
vals_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
553546

554547
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
555548
if not isinstance(vals, dpt.usm_ndarray):

dpctl/tensor/_indexing_functions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def take(x, indices, /, *, axis=None, mode="clip"):
8282
if len(indices) > 1:
8383
indices = dpt.broadcast_arrays(*indices)
8484
if x_ndim > 0:
85-
axis = operator.index(axis)
86-
axis = normalize_axis_index(axis, x_ndim)
85+
axis = normalize_axis_index(operator.index(axis), x_ndim)
8786
res_shape = (
8887
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
8988
)
@@ -154,13 +153,12 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
154153
x = dpt.reshape(x, (x.size,), copy=False)
155154
axis = 0
156155
except ValueError:
157-
raise ValueError("Cannot create 1D view of array")
156+
raise ValueError("Cannot create 1D view of input array")
158157
if len(indices) > 1:
159158
indices = dpt.broadcast_arrays(*indices)
160159
x_ndim = x.ndim
161160
if x_ndim > 0:
162-
axis = operator.index(axis)
163-
axis = normalize_axis_index(axis, x_ndim)
161+
axis = normalize_axis_index(operator.index(axis), x_ndim)
164162

165163
val_shape = (
166164
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,8 @@ def test_put_0d_val(data_dt):
546546
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0]))
547547

548548
x = dpt.asarray(5, dtype=data_dt, sycl_queue=q)
549-
x[ind] = 2
549+
val = 2
550+
dpt.put(x, ind, val)
550551
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x))
551552

552553

0 commit comments

Comments
 (0)