Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 45 additions & 48 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@
import math
import operator
import warnings
from collections import namedtuple
from typing import NamedTuple

import dpctl
import dpctl.tensor as dpt
import numpy
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_index
Expand All @@ -55,20 +56,20 @@
from .dpnp_utils import get_usm_allocations
from .dpnp_utils.dpnp_utils_pad import dpnp_pad

Parameters = namedtuple(
"Parameters_insert_delete",
[
"a",
"a_ndim",
"order",
"axis",
"slobj",
"n",
"a_shape",
"exec_q",
"usm_type",
],
)

class InsertDeleteParams(NamedTuple):
"""Parameters used for ``dpnp.delete`` and ``dpnp.insert``."""

a: dpnp_array
a_ndim: int
order: str
axis: int
slobj: list
n: int
a_shape: list
exec_q: dpctl.SyclQueue
usm_type: str


__all__ = [
"append",
Expand Down Expand Up @@ -140,7 +141,7 @@ def _check_stack_arrays(arrays):
def _delete_with_slice(params, obj, axis):
"""Utility function for ``dpnp.delete`` when obj is slice."""

a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = params
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = params

start, stop, step = obj.indices(n)
xr = range(start, stop, step)
Expand All @@ -155,11 +156,11 @@ def _delete_with_slice(params, obj, axis):
start = xr[-1]
stop = xr[0] + 1

a_shape[axis] -= num_del
newshape[axis] -= num_del
new = dpnp.empty(
a_shape,
dtype=a.dtype,
newshape,
order=order,
dtype=a.dtype,
sycl_queue=exec_q,
usm_type=usm_type,
)
Expand Down Expand Up @@ -201,7 +202,7 @@ def _delete_with_slice(params, obj, axis):
def _delete_without_slice(params, obj, axis, single_value):
"""Utility function for ``dpnp.delete`` when obj is int or array of int."""

a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = params
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = params

if single_value:
# optimization for a single value
Expand All @@ -212,11 +213,11 @@ def _delete_without_slice(params, obj, axis, single_value):
)
if obj < 0:
obj += n
a_shape[axis] -= 1
newshape[axis] -= 1
new = dpnp.empty(
a_shape,
dtype=a.dtype,
newshape,
order=order,
dtype=a.dtype,
sycl_queue=exec_q,
usm_type=usm_type,
)
Expand Down Expand Up @@ -265,18 +266,9 @@ def _calc_parameters(a, axis, obj, values=None):
n = a.shape[axis]
a_shape = list(a.shape)

if dpnp.is_supported_array_type(obj) and dpnp.is_supported_array_type(
values
):
usm_type, exec_q = get_usm_allocations([a, obj, values])
elif dpnp.is_supported_array_type(values):
usm_type, exec_q = get_usm_allocations([a, values])
elif dpnp.is_supported_array_type(obj):
usm_type, exec_q = get_usm_allocations([a, obj])
else:
usm_type, exec_q = a.usm_type, a.sycl_queue
usm_type, exec_q = get_usm_allocations([a, obj, values])

return Parameters(
return InsertDeleteParams(
a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type
)

Expand All @@ -288,7 +280,7 @@ def _insert_array_indices(parameters, indices, values, obj):

"""

a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = parameters
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = parameters

is_array = isinstance(obj, (dpnp_array, numpy.ndarray, dpt.usm_ndarray))
if indices.size == 0 and not is_array:
Expand All @@ -303,16 +295,16 @@ def _insert_array_indices(parameters, indices, values, obj):
numnew, dtype=indices.dtype, sycl_queue=exec_q, usm_type=usm_type
)

a_shape[axis] += numnew
newshape[axis] += numnew
old_mask = dpnp.ones(
a_shape[axis], dtype=dpnp.bool, sycl_queue=exec_q, usm_type=usm_type
newshape[axis], dtype=dpnp.bool, sycl_queue=exec_q, usm_type=usm_type
)
old_mask[indices] = False

new = dpnp.empty(
a_shape,
dtype=a.dtype,
newshape,
order=order,
dtype=a.dtype,
sycl_queue=exec_q,
usm_type=usm_type,
)
Expand All @@ -332,24 +324,27 @@ def _insert_singleton_index(parameters, indices, values, obj):

"""

a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = parameters
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = parameters

# In dpnp, `.item()` calls `.wait()`, so it is preferred to avoid it
# When possible (i.e. for numpy arrays, lists, etc), it is preferred
# to use `.item()` on a NumPy array
if isinstance(obj, (slice, dpnp_array, dpt.usm_ndarray)):
if isinstance(obj, (dpnp_array, dpt.usm_ndarray)):
index = indices.item()
else:
if isinstance(obj, slice):
obj = numpy.arange(*obj.indices(n), dtype=dpnp.intp)
index = numpy.asarray(obj).item()

if index < -n or index > n:
raise IndexError(
f"index {index} is out of bounds for axis {axis} " f"with size {n}"
f"index {index} is out of bounds for axis {axis} with size {n}"
)
if index < 0:
index += n

# There are some object array corner cases here, that cannot be avoided
# Need to change the dtype of values to input array dtype and update
# its shape to make ``input_arr[..., index, ...] = values`` legal
values = dpnp.array(
values,
copy=None,
Expand All @@ -362,15 +357,17 @@ def _insert_singleton_index(parameters, indices, values, obj):
# numpy.insert behave differently if obj is an scalar or an array
# with one element, so, this change is needed to align with NumPy
values = dpnp.moveaxis(values, 0, axis)

numnew = values.shape[axis]
a_shape[axis] += numnew
newshape[axis] += numnew
new = dpnp.empty(
a_shape,
dtype=a.dtype,
newshape,
order=order,
dtype=a.dtype,
sycl_queue=exec_q,
usm_type=usm_type,
)

slobj[axis] = slice(None, index)
new[tuple(slobj)] = a[tuple(slobj)]
slobj[axis] = slice(index, index + numnew)
Expand Down Expand Up @@ -2265,8 +2262,8 @@ def insert(arr, obj, values, axis=None):
)
else:
# need to copy obj, because indices will be changed in-place
indices = dpnp.array(
obj, copy=True, sycl_queue=params.exec_q, usm_type=params.usm_type
indices = dpnp.copy(
obj, sycl_queue=params.exec_q, usm_type=params.usm_type
)
if indices.dtype == dpnp.bool:
warnings.warn(
Expand Down
Loading
Loading