Skip to content

Commit 4975148

Browse files
Update _validate_interp_param()
1 parent 3c08905 commit 4975148

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def _validate_interp_param(param, name, exec_q, usm_type, dtype=None):
354354
"""
355355
Validate and convert optional parameters for interpolation.
356356
357-
Returns a USM array or None if the input is None.
357+
Returns a dpnp.ndarray or None if the input is None.
358358
"""
359359
if param is None:
360360
return None
@@ -372,10 +372,10 @@ def _validate_interp_param(param, name, exec_q, usm_type, dtype=None):
372372
)
373373
if dtype is not None:
374374
param = param.astype(dtype)
375-
return param.get_array()
375+
return param
376376

377377
if dpnp.isscalar(param):
378-
return dpt.asarray(
378+
return dpnp.asarray(
379379
param, dtype=dtype, sycl_queue=exec_q, usm_type=usm_type
380380
)
381381

@@ -2924,18 +2924,17 @@ def interp(x, xp, fp, left=None, right=None, period=None):
29242924
fp = dpnp.concatenate((fp[-1:], fp, fp[0:1]))
29252925

29262926
idx = dpnp.searchsorted(xp, x, side="right")
2927-
left_usm = _validate_interp_param(left, "left", exec_q, usm_type, fp.dtype)
2928-
right_usm = _validate_interp_param(
2929-
right, "right", exec_q, usm_type, fp.dtype
2930-
)
2927+
left = _validate_interp_param(left, "left", exec_q, usm_type, fp.dtype)
2928+
right = _validate_interp_param(right, "right", exec_q, usm_type, fp.dtype)
29312929

2932-
usm_type, exec_q = get_usm_allocations(
2933-
[x, xp, fp, period, left_usm, right_usm]
2934-
)
2930+
usm_type, exec_q = get_usm_allocations([x, xp, fp, period, left, right])
29352931
output = dpnp.empty(
29362932
x.shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type
29372933
)
29382934

2935+
left_usm = left.get_array() if left is not None else None
2936+
right_usm = right.get_array() if right is not None else None
2937+
29392938
_manager = dpu.SequentialOrderManager[exec_q]
29402939
mem_ev, ht_ev = ufi._interpolate(
29412940
x.get_array(),

0 commit comments

Comments
 (0)