Skip to content

Commit 04e4b51

Browse files
Ensure that rhs can be a scalar or numpy array
1 parent efcd9cb commit 04e4b51

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ def _mock_nonzero(ary):
428428

429429

430430
def _mock_take_multi_index(ary, inds, p):
431+
if not isinstance(ary, dpt.usm_ndarray):
432+
raise TypeError
431433
queues_ = [
432434
ary.sycl_queue,
433435
]
@@ -459,9 +461,15 @@ def _mock_take_multi_index(ary, inds, p):
459461

460462

461463
def _mock_place(ary, ary_mask, p, vals):
464+
if not isinstance(ary, dpt.usm_ndarray):
465+
raise TypeError
466+
if not isinstance(ary_mask, dpt.usm_ndarray):
467+
raise TypeError
462468
exec_q = dpctl.utils.get_execution_queue(
463-
(ary.sycl_queue, ary_mask.sycl_queue, vals.sycl_queue)
469+
(ary.sycl_queue, ary_mask.sycl_queue)
464470
)
471+
if exec_q is not None and isinstance(vals, dpt.usm_ndarray):
472+
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
465473
if exec_q is None:
466474
raise dpctl.utils.ExecutionPlacementError(
467475
"Can not automatically determine where to allocate the "
@@ -472,17 +480,32 @@ def _mock_place(ary, ary_mask, p, vals):
472480

473481
ary_np = dpt.asnumpy(ary)
474482
mask_np = dpt.asnumpy(ary_mask)
475-
vals_np = dpt.asnumpy(vals)
483+
if isinstance(vals, dpt.usm_ndarray) or hasattr(
484+
vals, "__sycl_usm_array_interface__"
485+
):
486+
vals_np = dpt.asnumpy(vals)
487+
else:
488+
vals_np = vals
476489
ary_np[(slice(None),) * p + (mask_np,)] = vals_np
477490
ary[...] = ary_np
478491
return
479492

480493

481494
def _mock_put_multi_index(ary, inds, p, vals):
482-
queues_ = [ary.sycl_queue, vals.sycl_queue]
483-
usm_types_ = [ary.usm_type, vals.usm_type]
495+
if isinstance(vals, dpt.ums_ndarray):
496+
queues_ = [ary.sycl_queue, vals.sycl_queue]
497+
usm_types_ = [ary.usm_type, vals.usm_type]
498+
else:
499+
queues_ = [
500+
ary.sycl_queue,
501+
]
502+
usm_types_ = [
503+
ary.usm_type,
504+
]
484505
all_integers = True
485506
for ind in inds:
507+
if not isinstance(ind, dpt.usm_ndarray):
508+
raise TypeError
486509
queues_.append(ind.sycl_queue)
487510
usm_types_.append(ind.usm_type)
488511
if all_integers:
@@ -500,7 +523,12 @@ def _mock_put_multi_index(ary, inds, p, vals):
500523
"arrays used as indices must be of integer (or boolean) type"
501524
)
502525
ary_np = dpt.asnumpy(ary)
503-
vals_np = dpt.asnumpy(vals)
526+
if isinstance(vals, dpt.usm_ndarray) or hasattr(
527+
vals, "__sycl_usm_array_interface__"
528+
):
529+
vals_np = dpt.asnumpy(vals)
530+
else:
531+
vals_np = vals
504532
ind_np = (slice(None),) * p + tuple(dpt.asnumpy(ind) for ind in inds)
505533
ary_np[ind_np] = vals_np
506534
ary[...] = ary_np

0 commit comments

Comments
 (0)