|
21 | 21 |
|
22 | 22 | import dpctl
|
23 | 23 | import dpctl.tensor as dpt
|
24 |
| -from dpctl.tensor._tensor_impl import _put, _take |
| 24 | +import dpctl.tensor._tensor_impl as ti |
25 | 25 |
|
26 |
| -from ._copy_utils import _extract_impl, _nonzero_impl, _place_impl |
| 26 | +from ._copy_utils import _extract_impl, _nonzero_impl |
27 | 27 |
|
28 | 28 |
|
29 | 29 | def take(x, indices, /, *, axis=None, mode="clip"):
|
@@ -95,7 +95,7 @@ def take(x, indices, /, *, axis=None, mode="clip"):
|
95 | 95 | res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
|
96 | 96 | )
|
97 | 97 |
|
98 |
| - hev, _ = _take(x, indices, res, axis, mode, sycl_queue=exec_q) |
| 98 | + hev, _ = ti._take(x, indices, res, axis, mode, sycl_queue=exec_q) |
99 | 99 | hev.wait()
|
100 | 100 |
|
101 | 101 | return res
|
@@ -175,7 +175,7 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
|
175 | 175 |
|
176 | 176 | vals = dpt.broadcast_to(vals, val_shape)
|
177 | 177 |
|
178 |
| - hev, _ = _put(x, indices, vals, axis, mode, sycl_queue=exec_q) |
| 178 | + hev, _ = ti._put(x, indices, vals, axis, mode, sycl_queue=exec_q) |
179 | 179 | hev.wait()
|
180 | 180 |
|
181 | 181 |
|
@@ -265,8 +265,23 @@ def place(arr, mask, vals):
|
265 | 265 | raise dpctl.utils.ExecutionPlacementError
|
266 | 266 | if arr.shape != mask.shape or vals.ndim != 1:
|
267 | 267 | raise ValueError("Array sizes are not as required")
|
268 |
| - # FIXME |
269 |
| - _place_impl(arr, mask, vals, axis=0) |
| 268 | + cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q) |
| 269 | + nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q) |
| 270 | + if nz_count == 0: |
| 271 | + return |
| 272 | + if vals.dtype == arr.dtype: |
| 273 | + rhs = vals |
| 274 | + else: |
| 275 | + rhs = dpt.astype(vals, arr.dtype) |
| 276 | + hev, _ = ti._place( |
| 277 | + dst=arr, |
| 278 | + cumsum=cumsum, |
| 279 | + axis_start=0, |
| 280 | + axis_end=mask.ndim, |
| 281 | + rhs=rhs, |
| 282 | + sycl_queue=exec_q, |
| 283 | + ) |
| 284 | + hev.wait() |
270 | 285 |
|
271 | 286 |
|
272 | 287 | def nonzero(arr):
|
|
0 commit comments