Skip to content

Commit aa102bd

Browse files
Add usm_type tests for interp()
1 parent cbe7e7a commit aa102bd

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2921,15 +2921,19 @@ def interp(x, xp, fp, left=None, right=None, period=None):
29212921
assert xp.flags.c_contiguous
29222922
assert fp.flags.c_contiguous
29232923

2924-
output = dpnp.empty(
2925-
x.shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type
2926-
)
29272924
idx = dpnp.searchsorted(xp, x, side="right")
29282925
left_usm = _validate_interp_param(left, "left", exec_q, usm_type, fp.dtype)
29292926
right_usm = _validate_interp_param(
29302927
right, "right", exec_q, usm_type, fp.dtype
29312928
)
29322929

2930+
usm_type, exec_q = get_usm_allocations(
2931+
[x, xp, fp, period, left_usm, right_usm]
2932+
)
2933+
output = dpnp.empty(
2934+
x.shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type
2935+
)
2936+
29332937
_manager = dpu.SequentialOrderManager[exec_q]
29342938
mem_ev, ht_ev = ufi._interpolate(
29352939
x.get_array(),

dpnp/tests/test_usm_type.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,65 @@ def test_choose(usm_type_x, usm_type_ind):
12681268
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
12691269

12701270

1271+
class TestInterp:
1272+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
1273+
@pytest.mark.parametrize("usm_type_xp", list_of_usm_types)
1274+
@pytest.mark.parametrize("usm_type_fp", list_of_usm_types)
1275+
def test_basic(self, usm_type_x, usm_type_xp, usm_type_fp):
1276+
x = dpnp.linspace(0.1, 9.9, 20, usm_type=usm_type_x)
1277+
xp = dpnp.linspace(0.0, 10.0, 5, usm_type=usm_type_xp)
1278+
fp = dpnp.array(xp * 2 + 1, usm_type=usm_type_fp)
1279+
1280+
result = dpnp.interp(x, xp, fp)
1281+
1282+
assert x.usm_type == usm_type_x
1283+
assert xp.usm_type == usm_type_xp
1284+
assert fp.usm_type == usm_type_fp
1285+
assert result.usm_type == du.get_coerced_usm_type(
1286+
[usm_type_x, usm_type_xp, usm_type_fp]
1287+
)
1288+
1289+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
1290+
@pytest.mark.parametrize("usm_type_left", list_of_usm_types)
1291+
@pytest.mark.parametrize("usm_type_right", list_of_usm_types)
1292+
def test_left_right(self, usm_type_x, usm_type_left, usm_type_right):
1293+
x = dpnp.linspace(-1.0, 11.0, 5, usm_type=usm_type_x)
1294+
xp = dpnp.linspace(0.0, 10.0, 5, usm_type=usm_type_x)
1295+
fp = dpnp.array(xp * 2 + 1, usm_type=usm_type_x)
1296+
1297+
left = dpnp.array(-100, usm_type=usm_type_left)
1298+
right = dpnp.array(100, usm_type=usm_type_right)
1299+
1300+
result = dpnp.interp(x, xp, fp, left=left, right=right)
1301+
1302+
assert left.usm_type == usm_type_left
1303+
assert right.usm_type == usm_type_right
1304+
assert result.usm_type == du.get_coerced_usm_type(
1305+
[
1306+
x.usm_type,
1307+
xp.usm_type,
1308+
fp.usm_type,
1309+
left.usm_type,
1310+
right.usm_type,
1311+
]
1312+
)
1313+
1314+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
1315+
@pytest.mark.parametrize("usm_type_period", list_of_usm_types)
1316+
def test_period(self, usm_type_x, usm_type_period):
1317+
x = dpnp.linspace(0.1, 9.9, 20, usm_type=usm_type_x)
1318+
xp = dpnp.linspace(0.0, 10.0, 5, usm_type=usm_type_x)
1319+
fp = dpnp.array(xp * 2 + 1, usm_type=usm_type_x)
1320+
period = dpnp.array(10.0, usm_type=usm_type_period)
1321+
1322+
result = dpnp.interp(x, xp, fp, period=period)
1323+
1324+
assert period.usm_type == usm_type_period
1325+
assert result.usm_type == du.get_coerced_usm_type(
1326+
[x.usm_type, xp.usm_type, fp.usm_type, period.usm_type]
1327+
)
1328+
1329+
12711330
@pytest.mark.parametrize("usm_type", list_of_usm_types)
12721331
class TestLinAlgebra:
12731332
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)