Skip to content

Commit 703871f

Browse files
committed
Break up compress to satisfy pylint
Also disable checks for protected access, as `compress` uses dpctl.tensor private functions
1 parent 51d5ea0 commit 703871f

File tree

1 file changed

+73
-69
lines changed

1 file changed

+73
-69
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 73 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
# pylint: disable=protected-access
41+
4042
import operator
4143

4244
import dpctl.tensor as dpt
@@ -158,6 +160,71 @@ def choose(x1, choices, out=None, mode="raise"):
158160
return call_origin(numpy.choose, x1, choices, out, mode)
159161

160162

163+
def _take_1d_index(x, inds, axis, q, usm_type, out=None):
    """
    Gather slices of `x` along `axis` at the positions given by `inds`.

    Argument validation for `x`, `inds` and `axis` is assumed to have been
    done by the caller.

    Parameters
    ----------
    x : usm_ndarray
        Source array to take elements from.
    inds : tuple of usm_ndarray
        Index arrays. The shape of ``inds[0]`` replaces the `axis` dimension
        of `x` in the result shape; the whole tuple is forwarded to
        ``ti._take``.
    axis : int
        Axis of `x` to index along. It is used directly for shape slicing,
        so it is expected to be already normalized (non-negative) by the
        caller.
    q : SyclQueue
        Queue used to submit the kernels and to allocate a new result array.
    usm_type : str
        USM allocation type of the result when `out` is not provided.
    out : {None, dpnp.ndarray, usm_ndarray}, optional
        Destination array. It must be writable, have the expected result
        shape, the same data type as `x`, and a queue compatible with `q`.

    Returns
    -------
    out : usm_ndarray or the object passed as `out`
        A newly allocated ``usm_ndarray`` when `out` is ``None``, otherwise
        the very object that was passed as `out`, filled with the result.

    Raises
    ------
    IndexError
        If `axis` of `x` has zero length while the indices are non-empty.
    ValueError
        If `out` is read-only, or has an unexpected shape or data type.
    ExecutionPlacementError
        If the allocation queue of `out` is not compatible with `q`.

    """

    x_sh = x.shape
    ind0 = inds[0]
    axis_end = axis + 1
    if 0 in x_sh[axis:axis_end] and ind0.size != 0:
        raise IndexError("cannot take non-empty indices from an empty axis")
    res_sh = x_sh[:axis] + ind0.shape + x_sh[axis_end:]

    # `out` may be a wrapper around a usm_ndarray, so keep three distinct
    # references: the caller's object (returned as-is), the unwrapped
    # usm_ndarray (final destination), and the array the kernel writes to
    # (a temporary buffer when source and destination overlap).
    user_out = out
    if out is None:
        out_ary = None
        dst = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q)
    else:
        dpnp.check_supported_arrays_type(out)
        out_ary = dpnp.get_usm_ndarray(out)

        if not out_ary.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if out_ary.shape != res_sh:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {res_sh}, got {out_ary.shape}"
            )

        if x.dtype != out_ary.dtype:
            raise ValueError(
                f"Output array of type {x.dtype} is needed, "
                f"got {out_ary.dtype}"
            )

        if dpu.get_execution_queue((q, out_ary.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        if ti._array_overlap(x, out_ary):
            # Allocate a temporary buffer to avoid memory overlapping.
            dst = dpt.empty_like(out_ary)
        else:
            dst = out_ary

    _manager = dpu.SequentialOrderManager[q]
    dep_evs = _manager.submitted_events

    # always use wrap mode here
    h_ev, take_ev = ti._take(
        src=x,
        ind=inds,
        dst=dst,
        axis_start=axis,
        mode=0,
        sycl_queue=q,
        depends=dep_evs,
    )
    _manager.add_event_pair(h_ev, take_ev)

    if out_ary is None:
        return dst

    if dst is not out_ary:
        # Copy the data from the temporary buffer to the original memory.
        # The destination must be the unwrapped usm_ndarray, not the
        # caller's wrapper object.
        ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=dst, dst=out_ary, sycl_queue=q, depends=[take_ev]
        )
        _manager.add_event_pair(ht_copy_ev, cpy_ev)

    return user_out
226+
227+
161228
def compress(condition, a, axis=None, out=None):
162229
"""
163230
Return selected slices of an array along given axis.
@@ -195,8 +262,7 @@ def compress(condition, a, axis=None, out=None):
195262
if a.ndim != 1:
196263
a = dpnp.ravel(a)
197264
axis = 0
198-
else:
199-
axis = normalize_axis_index(operator.index(axis), a.ndim)
265+
axis = normalize_axis_index(operator.index(axis), a.ndim)
200266

201267
a_ary = dpnp.get_usm_ndarray(a)
202268
if not dpnp.is_supported_array_type(condition):
@@ -216,7 +282,7 @@ def compress(condition, a, axis=None, out=None):
216282
usm_types_ = [a_ary.usm_type, cond_ary.usm_type]
217283
if not cond_ary.ndim == 1:
218284
raise ValueError(
219-
"`condition` must be a 1-D array or un-nested " "sequence"
285+
"`condition` must be a 1-D array or un-nested sequence"
220286
)
221287

222288
res_usm_type = dpu.get_coerced_usm_type(usm_types_)
@@ -226,74 +292,12 @@ def compress(condition, a, axis=None, out=None):
226292
"arrays must be allocated on the same SYCL queue"
227293
)
228294

229-
inds = _nonzero_impl(cond_ary) # synchronizes
230-
231-
res_dt = a_ary.dtype
232-
ind0 = inds[0]
233-
a_sh = a_ary.shape
234-
axis_end = axis + 1
235-
if 0 in a_sh[axis:axis_end] and ind0.size != 0:
236-
raise IndexError("cannot take non-empty indices from an empty axis")
237-
res_sh = a_sh[:axis] + ind0.shape + a_sh[axis_end:]
238-
239-
orig_out = out
240-
if out is not None:
241-
dpnp.check_supported_arrays_type(out)
242-
out = dpnp.get_usm_ndarray(out)
243-
244-
if not out.flags.writable:
245-
raise ValueError("provided `out` array is read-only")
246-
247-
if out.shape != res_sh:
248-
raise ValueError(
249-
"The shape of input and output arrays are inconsistent. "
250-
f"Expected output shape is {res_sh}, got {out.shape}"
251-
)
252-
253-
if res_dt != out.dtype:
254-
raise ValueError(
255-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
256-
)
257-
258-
if dpu.get_execution_queue((a_ary.sycl_queue, out.sycl_queue)) is None:
259-
raise dpu.ExecutionPlacementError(
260-
"Input and output allocation queues are not compatible"
261-
)
262-
263-
if ti._array_overlap(a_ary, out):
264-
# Allocate a temporary buffer to avoid memory overlapping.
265-
out = dpt.empty_like(out)
266-
else:
267-
out = dpt.empty(
268-
res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
269-
)
270-
271-
if out.size == 0:
272-
return out
295+
# _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
296+
inds = _nonzero_impl(cond_ary)
273297

274-
_manager = dpu.SequentialOrderManager[exec_q]
275-
dep_evs = _manager.submitted_events
276-
277-
h_ev, take_ev = ti._take(
278-
src=a_ary,
279-
ind=inds,
280-
dst=out,
281-
axis_start=axis,
282-
mode=0,
283-
sycl_queue=exec_q,
284-
depends=dep_evs,
298+
return dpnp.get_result_array(
299+
_take_1d_index(a_ary, inds, axis, exec_q, res_usm_type, out)
285300
)
286-
_manager.add_event_pair(h_ev, take_ev)
287-
288-
if not (orig_out is None or orig_out is out):
289-
# Copy the out data from temporary buffer to original memory
290-
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
291-
src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
292-
)
293-
_manager.add_event_pair(ht_copy_ev, cpy_ev)
294-
out = orig_out
295-
296-
return dpnp.get_result_array(out)
297301

298302

299303
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):

0 commit comments

Comments
 (0)