@@ -1268,6 +1268,65 @@ def test_choose(usm_type_x, usm_type_ind):
12681268 assert z .usm_type == du .get_coerced_usm_type ([usm_type_x , usm_type_ind ])
12691269
12701270
1271+ class TestInterp :
1272+ @pytest .mark .parametrize ("usm_type_x" , list_of_usm_types )
1273+ @pytest .mark .parametrize ("usm_type_xp" , list_of_usm_types )
1274+ @pytest .mark .parametrize ("usm_type_fp" , list_of_usm_types )
1275+ def test_basic (self , usm_type_x , usm_type_xp , usm_type_fp ):
1276+ x = dpnp .linspace (0.1 , 9.9 , 20 , usm_type = usm_type_x )
1277+ xp = dpnp .linspace (0.0 , 10.0 , 5 , usm_type = usm_type_xp )
1278+ fp = dpnp .array (xp * 2 + 1 , usm_type = usm_type_fp )
1279+
1280+ result = dpnp .interp (x , xp , fp )
1281+
1282+ assert x .usm_type == usm_type_x
1283+ assert xp .usm_type == usm_type_xp
1284+ assert fp .usm_type == usm_type_fp
1285+ assert result .usm_type == du .get_coerced_usm_type (
1286+ [usm_type_x , usm_type_xp , usm_type_fp ]
1287+ )
1288+
1289+ @pytest .mark .parametrize ("usm_type_x" , list_of_usm_types )
1290+ @pytest .mark .parametrize ("usm_type_left" , list_of_usm_types )
1291+ @pytest .mark .parametrize ("usm_type_right" , list_of_usm_types )
1292+ def test_left_right (self , usm_type_x , usm_type_left , usm_type_right ):
1293+ x = dpnp .linspace (- 1.0 , 11.0 , 5 , usm_type = usm_type_x )
1294+ xp = dpnp .linspace (0.0 , 10.0 , 5 , usm_type = usm_type_x )
1295+ fp = dpnp .array (xp * 2 + 1 , usm_type = usm_type_x )
1296+
1297+ left = dpnp .array (- 100 , usm_type = usm_type_left )
1298+ right = dpnp .array (100 , usm_type = usm_type_right )
1299+
1300+ result = dpnp .interp (x , xp , fp , left = left , right = right )
1301+
1302+ assert left .usm_type == usm_type_left
1303+ assert right .usm_type == usm_type_right
1304+ assert result .usm_type == du .get_coerced_usm_type (
1305+ [
1306+ x .usm_type ,
1307+ xp .usm_type ,
1308+ fp .usm_type ,
1309+ left .usm_type ,
1310+ right .usm_type ,
1311+ ]
1312+ )
1313+
1314+ @pytest .mark .parametrize ("usm_type_x" , list_of_usm_types )
1315+ @pytest .mark .parametrize ("usm_type_period" , list_of_usm_types )
1316+ def test_period (self , usm_type_x , usm_type_period ):
1317+ x = dpnp .linspace (0.1 , 9.9 , 20 , usm_type = usm_type_x )
1318+ xp = dpnp .linspace (0.0 , 10.0 , 5 , usm_type = usm_type_x )
1319+ fp = dpnp .array (xp * 2 + 1 , usm_type = usm_type_x )
1320+ period = dpnp .array (10.0 , usm_type = usm_type_period )
1321+
1322+ result = dpnp .interp (x , xp , fp , period = period )
1323+
1324+ assert period .usm_type == usm_type_period
1325+ assert result .usm_type == du .get_coerced_usm_type (
1326+ [x .usm_type , xp .usm_type , fp .usm_type , period .usm_type ]
1327+ )
1328+
1329+
12711330@pytest .mark .parametrize ("usm_type" , list_of_usm_types )
12721331class TestLinAlgebra :
12731332 @pytest .mark .parametrize (
0 commit comments