Skip to content

Commit 19691ca

Browse files
Implemented dpctl.tensor.place as per documented behavior.
1 parent 3ced89a commit 19691ca

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24-
from dpctl.tensor._tensor_impl import _put, _take
24+
import dpctl.tensor._tensor_impl as ti
2525

26-
from ._copy_utils import _extract_impl, _nonzero_impl, _place_impl
26+
from ._copy_utils import _extract_impl, _nonzero_impl
2727

2828

2929
def take(x, indices, /, *, axis=None, mode="clip"):
@@ -95,7 +95,7 @@ def take(x, indices, /, *, axis=None, mode="clip"):
9595
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
9696
)
9797

98-
hev, _ = _take(x, indices, res, axis, mode, sycl_queue=exec_q)
98+
hev, _ = ti._take(x, indices, res, axis, mode, sycl_queue=exec_q)
9999
hev.wait()
100100

101101
return res
@@ -175,7 +175,7 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
175175

176176
vals = dpt.broadcast_to(vals, val_shape)
177177

178-
hev, _ = _put(x, indices, vals, axis, mode, sycl_queue=exec_q)
178+
hev, _ = ti._put(x, indices, vals, axis, mode, sycl_queue=exec_q)
179179
hev.wait()
180180

181181

@@ -265,8 +265,23 @@ def place(arr, mask, vals):
265265
raise dpctl.utils.ExecutionPlacementError
266266
if arr.shape != mask.shape or vals.ndim != 1:
267267
raise ValueError("Array sizes are not as required")
268-
# FIXME
269-
_place_impl(arr, mask, vals, axis=0)
268+
cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q)
269+
nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q)
270+
if nz_count == 0:
271+
return
272+
if vals.dtype == arr.dtype:
273+
rhs = vals
274+
else:
275+
rhs = dpt.astype(vals, arr.dtype)
276+
hev, _ = ti._place(
277+
dst=arr,
278+
cumsum=cumsum,
279+
axis_start=0,
280+
axis_end=mask.ndim,
281+
rhs=rhs,
282+
sycl_queue=exec_q,
283+
)
284+
hev.wait()
270285

271286

272287
def nonzero(arr):

0 commit comments

Comments
 (0)