@@ -2227,19 +2227,37 @@ def test_complex(self, func, xp, dt):
22272227
22282228 @testing .with_requires ("numpy>=2.1.0" )
22292229 @pytest .mark .parametrize (
2230- "dt " , get_all_dtypes (no_none = True , no_complex = True )
2230+ "dt_in " , get_all_dtypes (no_none = True , no_complex = True )
22312231 )
2232- def test_out (self , func , dt ):
2233- a = generate_random_numpy_array (10 , dt )
2234- # TODO: use dt_out = dt when dpctl#2030 is fixed
2235- dt_out = numpy .int8 if dt == dpnp .bool else dt
2232+ @pytest .mark .parametrize (
2233+ "dt_out" , get_all_dtypes (no_none = True , no_complex = True )
2234+ )
2235+ def test_out (self , func , dt_in , dt_out ):
2236+ a = generate_random_numpy_array (10 , dt_in )
22362237 out = numpy .empty (a .shape , dtype = dt_out )
22372238 ia , iout = dpnp .array (a ), dpnp .array (out )
22382239
2239- expected = getattr (numpy , func )(a , out = out )
2240- result = getattr (dpnp , func )(ia , out = iout )
2241- assert result is iout
2242- assert_array_equal (result , expected )
2240+ if dt_in != dt_out :
2241+ if numpy .can_cast (dt_in , dt_out , casting = "same_kind" ):
2242+ # NumPy allows "same_kind" casting, dpnp does not
2243+ if func != "fix" and dt_in == dpnp .bool and dt_out == dpnp .int8 :
2244+ # TODO: get rid of w/a when dpctl#2030 is fixed
2245+ pass
2246+ else :
2247+ assert_raises (ValueError , getattr (dpnp , func ), ia , out = iout )
2248+ else :
2249+ assert_raises (ValueError , getattr (dpnp , func ), ia , out = iout )
2250+ assert_raises (TypeError , getattr (numpy , func ), a , out = out )
2251+ else :
2252+ if func != "fix" and dt_in == dpnp .bool :
2253+ # TODO: get rid of w/a when dpctl#2030 is fixed
2254+ out = out .astype (numpy .int8 )
2255+ iout = iout .astype (dpnp .int8 )
2256+
2257+ expected = getattr (numpy , func )(a , out = out )
2258+ result = getattr (dpnp , func )(ia , out = iout )
2259+ assert result is iout
2260+ assert_array_equal (result , expected )
22432261
22442262 @pytest .mark .skipif (not has_support_aspect16 (), reason = "no fp16 support" )
22452263 def test_out_float16 (self , func ):
@@ -2252,22 +2270,22 @@ def test_out_float16(self, func):
22522270 assert result is iout
22532271 assert_array_equal (result , expected )
22542272
2255- @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
22562273 @pytest .mark .parametrize (
2257- "dt_out " , get_all_dtypes (no_none = True , no_complex = True )[: - 1 ]
2274+ "dt " , get_all_dtypes (no_none = True , no_complex = True )
22582275 )
2259- def test_invalid_dtype (self , func , xp , dt_out ):
2260- dt_in = get_all_dtypes (no_none = True , no_complex = True )[- 1 ]
2261- a = xp .arange (10 , dtype = dt_in )
2262- out = xp .empty (10 , dtype = dt_out )
2263- if dt_out == numpy .float32 and dt_in == numpy .float64 :
2264- if xp == dpnp :
2265- # NumPy allows "same_kind" casting, dpnp does not
2266- assert_raises (ValueError , getattr (dpnp , func ), a , out = out )
2267- else :
2268- assert_raises (
2269- (ValueError , TypeError ), getattr (xp , func ), a , out = out
2270- )
2276+ def test_out_usm_ndarray (self , func , dt ):
2277+ a = generate_random_numpy_array (10 , dt )
2278+ out = numpy .empty (a .shape , dtype = dt )
2279+ ia , usm_out = dpnp .array (a ), dpt .asarray (out )
2280+
2281+ if func != "fix" and dt == dpnp .bool :
2282+ # TODO: get rid of w/a when dpctl#2030 is fixed
2283+ out = out .astype (numpy .int8 )
2284+ usm_out = dpt .asarray (usm_out , dtype = dpnp .int8 )
2285+
2286+ expected = getattr (numpy , func )(a , out = out )
2287+ result = getattr (dpnp , func )(ia , out = usm_out )
2288+ assert_array_equal (result , expected )
22712289
22722290 @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
22732291 @pytest .mark .parametrize (
@@ -2278,9 +2296,15 @@ def test_invalid_shape(self, func, xp, shape):
22782296 out = xp .empty (shape , dtype = xp .float32 )
22792297 assert_raises (ValueError , getattr (xp , func ), a , out = out )
22802298
2281- def test_scalar (self , func ):
2299+ def test_error (self , func ):
2300+ # scalar, unsupported input
22822301 assert_raises (TypeError , getattr (dpnp , func ), - 3.4 )
22832302
2303+ # unsupported out
2304+ a = dpnp .array ([1 , 2 , 3 ])
2305+ out = numpy .empty_like (3 , dtype = a .dtype )
2306+ assert_raises (TypeError , getattr (dpnp , func ), a , out = out )
2307+
22842308
22852309class TestHypot :
22862310 @pytest .mark .parametrize (
0 commit comments