4141import math
4242import operator
4343import warnings
44- from collections import namedtuple
44+ from typing import NamedTuple
4545
46+ import dpctl
4647import dpctl .tensor as dpt
4748import numpy
4849from dpctl .tensor ._numpy_helper import AxisError , normalize_axis_index
5556from .dpnp_utils import get_usm_allocations
5657from .dpnp_utils .dpnp_utils_pad import dpnp_pad
5758
58- Parameters = namedtuple (
59- "Parameters_insert_delete" ,
60- [
61- "a" ,
62- "a_ndim" ,
63- "order" ,
64- "axis" ,
65- "slobj" ,
66- "n" ,
67- "a_shape" ,
68- "exec_q" ,
69- "usm_type" ,
70- ],
71- )
59+
60+ class InsertDeleteParams ( NamedTuple ):
61+ """Parameters used for ``dpnp.delete`` and ``dpnp.insert``."""
62+
63+ a : dpnp_array
64+ a_ndim : int
65+ order : str
66+ axis : int
67+ slobj : list
68+ n : int
69+ a_shape : list
70+ exec_q : dpctl . SyclQueue
71+ usm_type : str
72+
7273
7374__all__ = [
7475 "append" ,
@@ -139,7 +140,7 @@ def _check_stack_arrays(arrays):
139140def _delete_with_slice (params , obj , axis ):
140141 """Utility function for ``dpnp.delete`` when obj is slice."""
141142
142- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = params
143+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = params
143144
144145 start , stop , step = obj .indices (n )
145146 xr = range (start , stop , step )
@@ -154,14 +155,8 @@ def _delete_with_slice(params, obj, axis):
154155 start = xr [- 1 ]
155156 stop = xr [0 ] + 1
156157
157- a_shape [axis ] -= num_del
158- new = dpnp .empty (
159- a_shape ,
160- dtype = a .dtype ,
161- order = order ,
162- sycl_queue = exec_q ,
163- usm_type = usm_type ,
164- )
158+ newshape [axis ] -= num_del
159+ new = dpnp .empty_like (a , order = order , shape = newshape )
165160 # copy initial chunk
166161 if start == 0 :
167162 pass
@@ -200,7 +195,7 @@ def _delete_with_slice(params, obj, axis):
200195def _delete_without_slice (params , obj , axis , single_value ):
201196 """Utility function for ``dpnp.delete`` when obj is int or array of int."""
202197
203- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = params
198+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = params
204199
205200 if single_value :
206201 # optimization for a single value
@@ -211,14 +206,8 @@ def _delete_without_slice(params, obj, axis, single_value):
211206 )
212207 if obj < 0 :
213208 obj += n
214- a_shape [axis ] -= 1
215- new = dpnp .empty (
216- a_shape ,
217- dtype = a .dtype ,
218- order = order ,
219- sycl_queue = exec_q ,
220- usm_type = usm_type ,
221- )
209+ newshape [axis ] -= 1
210+ new = dpnp .empty_like (a , order = order , shape = newshape )
222211 slobj [axis ] = slice (None , obj )
223212 new [tuple (slobj )] = a [tuple (slobj )]
224213 slobj [axis ] = slice (obj , None )
@@ -264,18 +253,9 @@ def _calc_parameters(a, axis, obj, values=None):
264253 n = a .shape [axis ]
265254 a_shape = list (a .shape )
266255
267- if dpnp .is_supported_array_type (obj ) and dpnp .is_supported_array_type (
268- values
269- ):
270- usm_type , exec_q = get_usm_allocations ([a , obj , values ])
271- elif dpnp .is_supported_array_type (values ):
272- usm_type , exec_q = get_usm_allocations ([a , values ])
273- elif dpnp .is_supported_array_type (obj ):
274- usm_type , exec_q = get_usm_allocations ([a , obj ])
275- else :
276- usm_type , exec_q = a .usm_type , a .sycl_queue
256+ usm_type , exec_q = get_usm_allocations ([a , obj , values ])
277257
278- return Parameters (
258+ return InsertDeleteParams (
279259 a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type
280260 )
281261
@@ -287,7 +267,7 @@ def _insert_array_indices(parameters, indices, values, obj):
287267
288268 """
289269
290- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = parameters
270+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = parameters
291271
292272 is_array = isinstance (obj , (dpnp_array , numpy .ndarray , dpt .usm_ndarray ))
293273 if indices .size == 0 and not is_array :
@@ -302,19 +282,13 @@ def _insert_array_indices(parameters, indices, values, obj):
302282 numnew , dtype = indices .dtype , sycl_queue = exec_q , usm_type = usm_type
303283 )
304284
305- a_shape [axis ] += numnew
285+ newshape [axis ] += numnew
306286 old_mask = dpnp .ones (
307- a_shape [axis ], dtype = dpnp .bool , sycl_queue = exec_q , usm_type = usm_type
287+ newshape [axis ], dtype = dpnp .bool , sycl_queue = exec_q , usm_type = usm_type
308288 )
309289 old_mask [indices ] = False
310290
311- new = dpnp .empty (
312- a_shape ,
313- dtype = a .dtype ,
314- order = order ,
315- sycl_queue = exec_q ,
316- usm_type = usm_type ,
317- )
291+ new = dpnp .empty_like (a , order = order , shape = newshape )
318292 slobj2 = [slice (None )] * a_ndim
319293 slobj [axis ] = indices
320294 slobj2 [axis ] = old_mask
@@ -331,24 +305,27 @@ def _insert_singleton_index(parameters, indices, values, obj):
331305
332306 """
333307
334- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = parameters
308+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = parameters
335309
336310 # In dpnp, `.item()` calls `.wait()`, so it is preferred to avoid it
337311 # When possible (i.e. for numpy arrays, lists, etc), it is preferred
338312 # to use `.item()` on a NumPy array
339- if isinstance (obj , (slice , dpnp_array , dpt .usm_ndarray )):
313+ if isinstance (obj , (dpnp_array , dpt .usm_ndarray )):
340314 index = indices .item ()
341315 else :
316+ if isinstance (obj , slice ):
317+ obj = numpy .arange (* obj .indices (n ), dtype = dpnp .intp )
342318 index = numpy .asarray (obj ).item ()
343319
344320 if index < - n or index > n :
345321 raise IndexError (
346- f"index { index } is out of bounds for axis { axis } " f" with size { n } "
322+ f"index { index } is out of bounds for axis { axis } with size { n } "
347323 )
348324 if index < 0 :
349325 index += n
350326
351- # There are some object array corner cases here, that cannot be avoided
327+ # Need to change the dtype of values to input array dtype and update
328+ # its shape to make ``input_arr[..., index, ...] = values`` legal
352329 values = dpnp .array (
353330 values ,
354331 copy = None ,
@@ -361,15 +338,11 @@ def _insert_singleton_index(parameters, indices, values, obj):
361338 # numpy.insert behave differently if obj is an scalar or an array
362339 # with one element, so, this change is needed to align with NumPy
363340 values = dpnp .moveaxis (values , 0 , axis )
341+
364342 numnew = values .shape [axis ]
365- a_shape [axis ] += numnew
366- new = dpnp .empty (
367- a_shape ,
368- dtype = a .dtype ,
369- order = order ,
370- sycl_queue = exec_q ,
371- usm_type = usm_type ,
372- )
343+ newshape [axis ] += numnew
344+ new = dpnp .empty_like (a , order = order , shape = newshape )
345+
373346 slobj [axis ] = slice (None , index )
374347 new [tuple (slobj )] = a [tuple (slobj )]
375348 slobj [axis ] = slice (index , index + numnew )
@@ -2229,8 +2202,8 @@ def insert(arr, obj, values, axis=None):
22292202 )
22302203 else :
22312204 # need to copy obj, because indices will be changed in-place
2232- indices = dpnp .array (
2233- obj , copy = True , sycl_queue = params .exec_q , usm_type = params .usm_type
2205+ indices = dpnp .copy (
2206+ obj , sycl_queue = params .exec_q , usm_type = params .usm_type
22342207 )
22352208 if indices .dtype == dpnp .bool :
22362209 warnings .warn (
0 commit comments