Skip to content

Commit 1387634

Browse files
committed
Advanced indices don't broadcast if 1 array passed
- _mock removed from indexing methods
1 parent f06dde5 commit 1387634

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def _mock_nonzero(ary):
430430
return tuple(dpt.asarray(i, usm_type=usm_type, sycl_queue=q) for i in nz)
431431

432432

433-
def _mock_take_multi_index(ary, inds, p):
433+
def _take_multi_index(ary, inds, p):
434434
if not isinstance(ary, dpt.usm_ndarray):
435435
raise TypeError
436436
queues_ = [
@@ -439,6 +439,8 @@ def _mock_take_multi_index(ary, inds, p):
439439
usm_types_ = [
440440
ary.usm_type,
441441
]
442+
if not isinstance(inds, list) and not isinstance(inds, tuple):
443+
inds = (inds,)
442444
all_integers = True
443445
for ind in inds:
444446
queues_.append(ind.sycl_queue)
@@ -452,7 +454,8 @@ def _mock_take_multi_index(ary, inds, p):
452454
raise IndexError(
453455
"arrays used as indices must be of integer (or boolean) type"
454456
)
455-
inds = dpt.broadcast_arrays(*inds)
457+
if (len(inds) > 1):
458+
inds = dpt.broadcast_arrays(*inds)
456459
ary_ndim = ary.ndim
457460
if ary_ndim > 0:
458461
p = operator.index(p)
@@ -505,7 +508,7 @@ def _mock_place(ary, ary_mask, p, vals):
505508
return
506509

507510

508-
def _mock_put_multi_index(ary, inds, p, vals):
511+
def _put_multi_index(ary, inds, p, vals):
509512
if isinstance(vals, dpt.usm_ndarray):
510513
queues_ = [ary.sycl_queue, vals.sycl_queue]
511514
usm_types_ = [ary.usm_type, vals.usm_type]
@@ -516,6 +519,8 @@ def _mock_put_multi_index(ary, inds, p, vals):
516519
usm_types_ = [
517520
ary.usm_type,
518521
]
522+
if not isinstance(inds, list) and not isinstance(inds, tuple):
523+
inds = (inds,)
519524
all_integers = True
520525
for ind in inds:
521526
if not isinstance(ind, dpt.usm_ndarray):
@@ -536,8 +541,8 @@ def _mock_put_multi_index(ary, inds, p, vals):
536541
raise IndexError(
537542
"arrays used as indices must be of integer (or boolean) type"
538543
)
539-
540-
inds = dpt.broadcast_arrays(*inds)
544+
if (len(inds) > 1):
545+
inds = dpt.broadcast_arrays(*inds)
541546
ary_ndim = ary.ndim
542547
if ary_ndim > 0:
543548
p = operator.index(p)

dpctl/tensor/_indexing_functions.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def take(x, indices, /, *, axis=None, mode="clip"):
7979
)
8080
axis = 0
8181

82-
indices = dpt.broadcast_arrays(*indices)
82+
if len(indices) > 1:
83+
indices = dpt.broadcast_arrays(*indices)
8384
if x_ndim > 0:
8485
axis = operator.index(axis)
8586
axis = normalize_axis_index(axis, x_ndim)
@@ -149,10 +150,13 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
149150

150151
# when axis is none, array is treated as 1D
151152
if axis is None:
152-
x = dpt.reshape(x, (x.size,), copy=False)
153-
axis = 0
154-
155-
indices = dpt.broadcast_arrays(*indices)
153+
try:
154+
x = dpt.reshape(x, (x.size,), copy=False)
155+
axis = 0
156+
except ValueError:
157+
raise ValueError("Cannot create 1D view of array")
158+
if len(indices) > 1:
159+
indices = dpt.broadcast_arrays(*indices)
156160
x_ndim = x.ndim
157161
if x_ndim > 0:
158162
axis = operator.index(axis)

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ cdef class usm_ndarray:
673673
from ._copy_utils import (
674674
_mock_extract,
675675
_mock_nonzero,
676-
_mock_take_multi_index,
676+
_take_multi_index,
677677
)
678678
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
679679
return _mock_extract(res, adv_ind[0], adv_ind_start_p)
@@ -685,9 +685,9 @@ cdef class usm_ndarray:
685685
adv_ind_int.extend(_mock_nonzero(ind))
686686
else:
687687
adv_ind_int.append(ind)
688-
return _mock_take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
688+
return _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
689689

690-
return _mock_take_multi_index(res, adv_ind, adv_ind_start_p)
690+
return _take_multi_index(res, adv_ind, adv_ind_start_p)
691691

692692

693693
def to_device(self, target):
@@ -1021,7 +1021,7 @@ cdef class usm_ndarray:
10211021
_copy_from_usm_ndarray_to_usm_ndarray,
10221022
_mock_nonzero,
10231023
_mock_place,
1024-
_mock_put_multi_index,
1024+
_put_multi_index,
10251025
)
10261026

10271027
adv_ind = _meta[3]
@@ -1064,10 +1064,10 @@ cdef class usm_ndarray:
10641064
adv_ind_int.extend(_mock_nonzero(ind))
10651065
else:
10661066
adv_ind_int.append(ind)
1067-
_mock_put_multi_index(Xv, tuple(adv_ind_int), adv_ind_start_p, rhs)
1067+
_put_multi_index(Xv, tuple(adv_ind_int), adv_ind_start_p, rhs)
10681068
return
10691069

1070-
_mock_put_multi_index(Xv, adv_ind, adv_ind_start_p, rhs)
1070+
_put_multi_index(Xv, adv_ind, adv_ind_start_p, rhs)
10711071
return
10721072

10731073

0 commit comments

Comments
 (0)