Skip to content

Commit 07304be

Browse files
committed
Break up compress to satisfy pylint
Also disable checks for protected access, as `compress` uses dpctl.tensor private functions
1 parent 466edd9 commit 07304be

File tree

1 file changed

+73
-69
lines changed

1 file changed

+73
-69
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 73 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
# pylint: disable=protected-access
41+
4042
import operator
4143

4244
import dpctl.tensor as dpt
@@ -159,6 +161,71 @@ def choose(x1, choices, out=None, mode="raise"):
159161
return call_origin(numpy.choose, x1, choices, out, mode)
160162

161163

164+
def _take_1d_index(x, inds, axis, q, usm_type, out=None):
    """
    Take elements of `x` along `axis` at the positions given by ``inds[0]``.

    Private helper used by `compress`; argument validation of `x`, `inds`
    and `axis` is assumed to have been done by the caller.

    Parameters
    ----------
    x : usm_ndarray
        Source array (the caller passes the result of
        ``dpnp.get_usm_ndarray``).
    inds : sequence of usm_ndarray
        Index arrays; only the first one determines the result shape, and
        the whole sequence is forwarded to ``ti._take``.
    axis : int
        Already-normalized axis of `x` along which to take.
    q : SyclQueue
        Execution queue used for allocation and kernel submission.
    usm_type : str
        USM type for a freshly allocated result (used when `out` is None).
    out : {None, dpnp.ndarray, usm_ndarray}, optional
        Destination array. It must be writable, have the expected result
        shape, the dtype of `x`, and a queue compatible with `q`.

    Returns
    -------
    The object passed as `out` when it is provided, otherwise a newly
    allocated usm_ndarray holding the result.

    Raises
    ------
    IndexError
        If non-empty indices are taken from an empty axis.
    ValueError
        If `out` is read-only, has the wrong shape, or the wrong dtype.
    ExecutionPlacementError
        If the queues of the input and `out` are not compatible.

    """

    # arg validation assumed done by caller
    x_sh = x.shape
    ind0 = inds[0]
    axis_end = axis + 1
    if 0 in x_sh[axis:axis_end] and ind0.size != 0:
        raise IndexError("cannot take non-empty indices from an empty axis")
    res_sh = x_sh[:axis] + ind0.shape + x_sh[axis_end:]

    # Keep two handles on the user-provided destination:
    #   * `user_out` is the object handed back to the caller unchanged;
    #   * `user_dst` is its usm_ndarray view, the only form that should be
    #     passed to the low-level `ti` routines (assumption: they take
    #     usm_ndarray only).
    # Comparing/copying against the un-converted object (as was done before)
    # breaks when `out` is a dpnp array rather than a usm_ndarray.
    user_out = out
    user_dst = None
    if out is not None:
        dpnp.check_supported_arrays_type(out)
        user_dst = dpnp.get_usm_ndarray(out)
        out = user_dst

        if not out.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if out.shape != res_sh:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {res_sh}, got {out.shape}"
            )

        if x.dtype != out.dtype:
            raise ValueError(
                f"Output array of type {x.dtype} is needed, got {out.dtype}"
            )

        if dpu.get_execution_queue((q, out.sycl_queue)) is None:
            raise dpu.ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        if ti._array_overlap(x, out):
            # Allocate a temporary buffer to avoid memory overlapping.
            out = dpt.empty_like(out)
    else:
        out = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q)

    _manager = dpu.SequentialOrderManager[q]
    dep_evs = _manager.submitted_events

    # always use wrap mode here
    h_ev, take_ev = ti._take(
        src=x,
        ind=inds,
        dst=out,
        axis_start=axis,
        mode=0,
        sycl_queue=q,
        depends=dep_evs,
    )
    _manager.add_event_pair(h_ev, take_ev)

    if user_out is None:
        return out

    if out is not user_dst:
        # A temporary buffer was used because of overlap: copy the result
        # from the temporary buffer into the user's memory.
        ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=out, dst=user_dst, sycl_queue=q, depends=[take_ev]
        )
        _manager.add_event_pair(ht_copy_ev, cpy_ev)

    return user_out
227+
228+
162229
def compress(condition, a, axis=None, out=None):
163230
"""
164231
Return selected slices of an array along given axis.
@@ -196,8 +263,7 @@ def compress(condition, a, axis=None, out=None):
196263
if a.ndim != 1:
197264
a = dpnp.ravel(a)
198265
axis = 0
199-
else:
200-
axis = normalize_axis_index(operator.index(axis), a.ndim)
266+
axis = normalize_axis_index(operator.index(axis), a.ndim)
201267

202268
a_ary = dpnp.get_usm_ndarray(a)
203269
if not dpnp.is_supported_array_type(condition):
@@ -217,7 +283,7 @@ def compress(condition, a, axis=None, out=None):
217283
usm_types_ = [a_ary.usm_type, cond_ary.usm_type]
218284
if not cond_ary.ndim == 1:
219285
raise ValueError(
220-
"`condition` must be a 1-D array or un-nested " "sequence"
286+
"`condition` must be a 1-D array or un-nested sequence"
221287
)
222288

223289
res_usm_type = dpu.get_coerced_usm_type(usm_types_)
@@ -227,74 +293,12 @@ def compress(condition, a, axis=None, out=None):
227293
"arrays must be allocated on the same SYCL queue"
228294
)
229295

230-
inds = _nonzero_impl(cond_ary) # synchronizes
231-
232-
res_dt = a_ary.dtype
233-
ind0 = inds[0]
234-
a_sh = a_ary.shape
235-
axis_end = axis + 1
236-
if 0 in a_sh[axis:axis_end] and ind0.size != 0:
237-
raise IndexError("cannot take non-empty indices from an empty axis")
238-
res_sh = a_sh[:axis] + ind0.shape + a_sh[axis_end:]
239-
240-
orig_out = out
241-
if out is not None:
242-
dpnp.check_supported_arrays_type(out)
243-
out = dpnp.get_usm_ndarray(out)
244-
245-
if not out.flags.writable:
246-
raise ValueError("provided `out` array is read-only")
247-
248-
if out.shape != res_sh:
249-
raise ValueError(
250-
"The shape of input and output arrays are inconsistent. "
251-
f"Expected output shape is {res_sh}, got {out.shape}"
252-
)
253-
254-
if res_dt != out.dtype:
255-
raise ValueError(
256-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
257-
)
258-
259-
if dpu.get_execution_queue((a_ary.sycl_queue, out.sycl_queue)) is None:
260-
raise dpu.ExecutionPlacementError(
261-
"Input and output allocation queues are not compatible"
262-
)
263-
264-
if ti._array_overlap(a_ary, out):
265-
# Allocate a temporary buffer to avoid memory overlapping.
266-
out = dpt.empty_like(out)
267-
else:
268-
out = dpt.empty(
269-
res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
270-
)
271-
272-
if out.size == 0:
273-
return out
296+
# _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
297+
inds = _nonzero_impl(cond_ary)
274298

275-
_manager = dpu.SequentialOrderManager[exec_q]
276-
dep_evs = _manager.submitted_events
277-
278-
h_ev, take_ev = ti._take(
279-
src=a_ary,
280-
ind=inds,
281-
dst=out,
282-
axis_start=axis,
283-
mode=0,
284-
sycl_queue=exec_q,
285-
depends=dep_evs,
299+
return dpnp.get_result_array(
300+
_take_1d_index(a_ary, inds, axis, exec_q, res_usm_type, out)
286301
)
287-
_manager.add_event_pair(h_ev, take_ev)
288-
289-
if not (orig_out is None or orig_out is out):
290-
# Copy the out data from temporary buffer to original memory
291-
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
292-
src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
293-
)
294-
_manager.add_event_pair(ht_copy_ev, cpy_ev)
295-
out = orig_out
296-
297-
return dpnp.get_result_array(out)
298302

299303

300304
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):

0 commit comments

Comments
 (0)