Skip to content

Commit 07216aa

Browse files
author
Vahid Tavanashad
committed
address comments
1 parent f381aa4 commit 07216aa

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
import dpnp
5050

5151
from .dpnp_array import dpnp_array
52+
53+
# pylint: disable=no-name-in-module
54+
from .dpnp_utils import get_usm_allocations
5255
from .dpnp_utils.dpnp_utils_pad import dpnp_pad
5356

5457
__all__ = [
@@ -176,7 +179,7 @@ def _delete_with_slice(a, obj, axis):
176179
return new
177180

178181

179-
def _delete_without_slice(a, obj, axis, single_value):
182+
def _delete_without_slice(a, obj, axis, single_value, exec_q, usm_type):
180183
"""Utility function for ``dpnp.delete`` when obj is int or array of int."""
181184

182185
a, a_ndim, order, axis, slobj, n, a_shape = _calc_parameters(a, axis)
@@ -194,8 +197,8 @@ def _delete_without_slice(a, obj, axis, single_value):
194197
a_shape,
195198
dtype=a.dtype,
196199
order=order,
197-
sycl_queue=a.sycl_queue,
198-
usm_type=a.usm_type,
200+
sycl_queue=exec_q,
201+
usm_type=usm_type,
199202
)
200203
slobj[axis] = slice(None, obj)
201204
new[tuple(slobj)] = a[tuple(slobj)]
@@ -215,7 +218,7 @@ def _delete_without_slice(a, obj, axis, single_value):
215218
keep = ~obj
216219
else:
217220
keep = dpnp.ones(
218-
n, dtype=dpnp.bool, sycl_queue=a.sycl_queue, usm_type=a.usm_type
221+
n, dtype=dpnp.bool, sycl_queue=exec_q, usm_type=usm_type
219222
)
220223
keep[obj,] = False
221224

@@ -1351,7 +1354,7 @@ def delete(arr, obj, axis=None):
13511354
obj : {slice, int, array-like of ints or boolean}
13521355
Indicate indices of sub-arrays to remove along the specified axis.
13531356
Boolean indices are treated as a mask of elements to remove.
1354-
axis : int, optional
1357+
axis : {None, int}, optional
13551358
The axis along which to delete the subarray defined by `obj`.
13561359
If `axis` is ``None``, `obj` is applied to the flattened array.
13571360
Default: ``None``.
@@ -1378,7 +1381,7 @@ def delete(arr, obj, axis=None):
13781381
>>> mask[0] = mask[2] = mask[4] = False
13791382
>>> result = arr[mask,...]
13801383
1381-
is equivalent to ``np.delete(arr, [0,2,4], axis=0)``, but allows further
1384+
is equivalent to ``np.delete(arr, [0, 2, 4], axis=0)``, but allows further
13821385
use of `mask`.
13831386
13841387
Examples
@@ -1407,14 +1410,17 @@ def delete(arr, obj, axis=None):
14071410
if isinstance(obj, slice):
14081411
return _delete_with_slice(arr, obj, axis)
14091412

1413+
if dpnp.is_supported_array_type(obj):
1414+
usm_type, exec_q = get_usm_allocations([arr, obj])
1415+
else:
1416+
usm_type, exec_q = arr.usm_type, arr.sycl_queue
1417+
14101418
if isinstance(obj, (int, dpnp.integer)) and not isinstance(obj, bool):
14111419
single_value = True
14121420
else:
14131421
single_value = False
14141422
is_array = isinstance(obj, (dpnp_array, numpy.ndarray, dpt.usm_ndarray))
1415-
obj = dpnp.asarray(
1416-
obj, sycl_queue=arr.sycl_queue, usm_type=arr.usm_type
1417-
)
1423+
obj = dpnp.asarray(obj, sycl_queue=exec_q, usm_type=usm_type)
14181424
# if `obj` is originally an empty list, after converting it into
14191425
# an array, it will have float dtype, so we need to change its dtype
14201426
# to integer. However, if `obj` is originally an empty array with
@@ -1427,7 +1433,7 @@ def delete(arr, obj, axis=None):
14271433
obj = obj.item()
14281434
single_value = True
14291435

1430-
return _delete_without_slice(arr, obj, axis, single_value)
1436+
return _delete_without_slice(arr, obj, axis, single_value, exec_q, usm_type)
14311437

14321438

14331439
def dsplit(ary, indices_or_sections):

0 commit comments

Comments
 (0)