Skip to content

Commit 167d654

Browse files
committed
Break up compress to satisfy pylint
Also disable checks for protected access, as `compress` uses dpctl.tensor private functions
1 parent 51d5ea0 commit 167d654

File tree

1 file changed

+81
-69
lines changed

1 file changed

+81
-69
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 81 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@
3737
3838
"""
3939

40+
# pylint: disable=protected-access
41+
4042
import operator
4143

44+
import dpctl
4245
import dpctl.tensor as dpt
4346
import dpctl.tensor._tensor_impl as ti
4447
import dpctl.utils as dpu
@@ -158,6 +161,78 @@ def choose(x1, choices, out=None, mode="raise"):
158161
return call_origin(numpy.choose, x1, choices, out, mode)
159162

160163

164+
def _take_1d_index(
    x: dpt.usm_ndarray,
    inds: tuple[dpt.usm_ndarray],
    axis: int,
    q: dpctl.SyclQueue,
    usm_type: str,
    out: dpt.usm_ndarray | None = None,
) -> dpt.usm_ndarray:
    """
    Take slices of `x` along `axis` at the positions given by ``inds[0]``.

    Argument validation of `x`, `inds` and `axis` is assumed to have been
    done by the caller; only `out` is validated here.

    Parameters
    ----------
    x : usm_ndarray
        Source array.
    inds : tuple of usm_ndarray
        Index arrays handed to ``ti._take``; the shape of ``inds[0]`` replaces
        the `axis` dimension of `x` in the result shape.
    axis : int
        Axis of `x` to take from (expected to be already normalized).
    q : dpctl.SyclQueue
        Queue used for allocation of the result and for task submission.
    usm_type : str
        USM type of the result when it is allocated by this function.
    out : array, optional
        Destination array. It must be writable, have the expected result
        shape, the dtype of `x`, and a queue compatible with `q`. It is
        passed through ``dpnp.get_usm_ndarray``, so any array type accepted
        by ``dpnp.check_supported_arrays_type`` may be given.

    Returns
    -------
    Array holding the result: a newly allocated usm_ndarray when `out` is
    ``None``, otherwise the very object the caller passed as `out`.

    Raises
    ------
    IndexError
        If the taken axis of `x` has zero length while ``inds[0]`` is not
        empty.
    ValueError
        If `out` is read-only, has an unexpected shape, or a dtype different
        from that of `x`.
    dpctl.utils.ExecutionPlacementError
        If `q` and the queue of `out` are not compatible.

    """

    # arg validation assumed done by caller
    x_sh = x.shape
    ind0 = inds[0]
    axis_end = axis + 1
    if 0 in x_sh[axis:axis_end] and ind0.size != 0:
        raise IndexError("cannot take non-empty indices from an empty axis")
    res_sh = x_sh[:axis] + ind0.shape + x_sh[axis_end:]

    # `caller_out` is the object handed in by the caller (and handed back to
    # them), `dst` is its underlying usm_ndarray, i.e. the real destination.
    # They must be tracked separately: `dpnp.get_usm_ndarray` may return an
    # object different from its argument, so comparing the caller's object
    # with the unwrapped one cannot tell whether a temporary buffer is in use.
    caller_out = out
    dst = None
    if out is not None:
        dpnp.check_supported_arrays_type(out)
        dst = dpnp.get_usm_ndarray(out)

        if not dst.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if dst.shape != res_sh:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {res_sh}, got {dst.shape}"
            )

        if x.dtype != dst.dtype:
            raise ValueError(
                f"Output array of type {x.dtype} is needed, got {dst.dtype}"
            )

        if dpu.get_execution_queue((q, dst.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        if ti._array_overlap(x, dst):
            # Allocate a temporary buffer to avoid memory overlapping.
            out = dpt.empty_like(dst)
        else:
            out = dst
    else:
        out = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q)

    _manager = dpu.SequentialOrderManager[q]
    dep_evs = _manager.submitted_events

    # always use wrap mode here
    h_ev, take_ev = ti._take(
        src=x,
        ind=inds,
        dst=out,
        axis_start=axis,
        mode=0,
        sycl_queue=q,
        depends=dep_evs,
    )
    _manager.add_event_pair(h_ev, take_ev)

    if dst is not None and out is not dst:
        # A temporary buffer was substituted for the destination: copy the
        # data from it to the original memory once the take has completed.
        ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=out, dst=dst, sycl_queue=q, depends=[take_ev]
        )
        _manager.add_event_pair(ht_copy_ev, cpy_ev)

    # Hand the caller back their own `out` object when one was provided.
    return out if caller_out is None else caller_out
234+
235+
161236
def compress(condition, a, axis=None, out=None):
162237
"""
163238
Return selected slices of an array along given axis.
@@ -195,8 +270,7 @@ def compress(condition, a, axis=None, out=None):
195270
if a.ndim != 1:
196271
a = dpnp.ravel(a)
197272
axis = 0
198-
else:
199-
axis = normalize_axis_index(operator.index(axis), a.ndim)
273+
axis = normalize_axis_index(operator.index(axis), a.ndim)
200274

201275
a_ary = dpnp.get_usm_ndarray(a)
202276
if not dpnp.is_supported_array_type(condition):
@@ -216,7 +290,7 @@ def compress(condition, a, axis=None, out=None):
216290
usm_types_ = [a_ary.usm_type, cond_ary.usm_type]
217291
if not cond_ary.ndim == 1:
218292
raise ValueError(
219-
"`condition` must be a 1-D array or un-nested " "sequence"
293+
"`condition` must be a 1-D array or un-nested sequence"
220294
)
221295

222296
res_usm_type = dpu.get_coerced_usm_type(usm_types_)
@@ -226,74 +300,12 @@ def compress(condition, a, axis=None, out=None):
226300
"arrays must be allocated on the same SYCL queue"
227301
)
228302

229-
inds = _nonzero_impl(cond_ary) # synchronizes
230-
231-
res_dt = a_ary.dtype
232-
ind0 = inds[0]
233-
a_sh = a_ary.shape
234-
axis_end = axis + 1
235-
if 0 in a_sh[axis:axis_end] and ind0.size != 0:
236-
raise IndexError("cannot take non-empty indices from an empty axis")
237-
res_sh = a_sh[:axis] + ind0.shape + a_sh[axis_end:]
238-
239-
orig_out = out
240-
if out is not None:
241-
dpnp.check_supported_arrays_type(out)
242-
out = dpnp.get_usm_ndarray(out)
243-
244-
if not out.flags.writable:
245-
raise ValueError("provided `out` array is read-only")
246-
247-
if out.shape != res_sh:
248-
raise ValueError(
249-
"The shape of input and output arrays are inconsistent. "
250-
f"Expected output shape is {res_sh}, got {out.shape}"
251-
)
252-
253-
if res_dt != out.dtype:
254-
raise ValueError(
255-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
256-
)
257-
258-
if dpu.get_execution_queue((a_ary.sycl_queue, out.sycl_queue)) is None:
259-
raise dpu.ExecutionPlacementError(
260-
"Input and output allocation queues are not compatible"
261-
)
262-
263-
if ti._array_overlap(a_ary, out):
264-
# Allocate a temporary buffer to avoid memory overlapping.
265-
out = dpt.empty_like(out)
266-
else:
267-
out = dpt.empty(
268-
res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
269-
)
270-
271-
if out.size == 0:
272-
return out
303+
# _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
304+
inds = _nonzero_impl(cond_ary)
273305

274-
_manager = dpu.SequentialOrderManager[exec_q]
275-
dep_evs = _manager.submitted_events
276-
277-
h_ev, take_ev = ti._take(
278-
src=a_ary,
279-
ind=inds,
280-
dst=out,
281-
axis_start=axis,
282-
mode=0,
283-
sycl_queue=exec_q,
284-
depends=dep_evs,
306+
return dpnp.get_result_array(
307+
_take_1d_index(a_ary, inds, axis, exec_q, res_usm_type, out)
285308
)
286-
_manager.add_event_pair(h_ev, take_ev)
287-
288-
if not (orig_out is None or orig_out is out):
289-
# Copy the out data from temporary buffer to original memory
290-
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
291-
src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
292-
)
293-
_manager.add_event_pair(ht_copy_ev, cpy_ev)
294-
out = orig_out
295-
296-
return dpnp.get_result_array(out)
297309

298310

299311
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):

0 commit comments

Comments
 (0)