@@ -4228,24 +4228,25 @@ def test_shapes(self, a_sh, b_sh):
42284228
42294229
42304230class TestMatmulInvalidCases :
4231+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
42314232 @pytest .mark .parametrize (
4232- "shape_pair " ,
4233+ "shape1, shape2 " ,
42334234 [
42344235 ((3 , 2 ), ()),
42354236 ((), (3 , 2 )),
42364237 ((), ()),
42374238 ],
42384239 )
4239- def test_zero_dim (self , shape_pair ):
4240- for xp in (numpy , dpnp ):
4241- shape1 , shape2 = shape_pair
4242- x1 = xp .arange (numpy .prod (shape1 ), dtype = xp .float32 ).reshape (shape1 )
4243- x2 = xp .arange (numpy .prod (shape2 ), dtype = xp .float32 ).reshape (shape2 )
4244- with pytest .raises (ValueError ):
4245- xp .matmul (x1 , x2 )
4240+ def test_zero_dim (self , xp , shape1 , shape2 ):
4241+ x1 = xp .arange (numpy .prod (shape1 ), dtype = xp .float32 ).reshape (shape1 )
4242+ x2 = xp .arange (numpy .prod (shape2 ), dtype = xp .float32 ).reshape (shape2 )
42464243
4244+ with pytest .raises (ValueError ):
4245+ xp .matmul (x1 , x2 )
4246+
4247+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
42474248 @pytest .mark .parametrize (
4248- "shape_pair " ,
4249+ "shape1, shape2 " ,
42494250 [
42504251 ((3 ,), (4 ,)),
42514252 ((2 , 3 ), (4 , 5 )),
@@ -4258,16 +4259,16 @@ def test_zero_dim(self, shape_pair):
42584259 ((6 , 5 , 3 , 2 ), (3 , 2 , 4 )),
42594260 ],
42604261 )
4261- def test_invalid_shape (self , shape_pair ):
4262- for xp in (numpy , dpnp ):
4263- shape1 , shape2 = shape_pair
4264- x1 = xp .arange (numpy .prod (shape1 ), dtype = xp .float32 ).reshape (shape1 )
4265- x2 = xp .arange (numpy .prod (shape2 ), dtype = xp .float32 ).reshape (shape2 )
4266- with pytest .raises (ValueError ):
4267- xp .matmul (x1 , x2 )
4262+ def test_invalid_shape (self , xp , shape1 , shape2 ):
4263+ x1 = xp .arange (numpy .prod (shape1 ), dtype = xp .float32 ).reshape (shape1 )
4264+ x2 = xp .arange (numpy .prod (shape2 ), dtype = xp .float32 ).reshape (shape2 )
42684265
4266+ with pytest .raises (ValueError ):
4267+ xp .matmul (x1 , x2 )
4268+
4269+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
42694270 @pytest .mark .parametrize (
4270- "shape_pair " ,
4271+ "shape1, shape2, out_shape " ,
42714272 [
42724273 ((5 , 4 , 3 ), (3 , 1 ), (3 , 4 , 1 )),
42734274 ((5 , 4 , 3 ), (3 , 1 ), (5 , 6 , 1 )),
@@ -4279,24 +4280,24 @@ def test_invalid_shape(self, shape_pair):
42794280 ((4 ,), (3 , 4 , 5 ), (3 , 6 )),
42804281 ],
42814282 )
4282- def test_invalid_shape_out (self , shape_pair ):
4283- for xp in (numpy , dpnp ):
4284- shape1 , shape2 , out_shape = shape_pair
4285- x1 = xp .arange (numpy .prod (shape1 ), dtype = xp .float32 ).reshape (shape1 )
4286- x2 = xp .arange (numpy .prod (shape2 ), dtype = xp .float32 ).reshape (shape2 )
4287- res = xp .empty (out_shape )
4288- with pytest .raises (ValueError ):
4289- xp .matmul (x1 , x2 , out = res )
4283+ def test_invalid_shape_out (self , xp , shape1 , shape2 , out_shape ):
4284+ x1 = xp .arange (numpy .prod (shape1 ), dtype = xp .float32 ).reshape (shape1 )
4285+ x2 = xp .arange (numpy .prod (shape2 ), dtype = xp .float32 ).reshape (shape2 )
4286+ res = xp .empty (out_shape )
42904287
4288+ with pytest .raises (ValueError ):
4289+ xp .matmul (x1 , x2 , out = res )
4290+
4291+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
42914292 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True )[:- 2 ])
4292- def test_invalid_dtype (self , dtype ):
4293+ def test_invalid_dtype (self , xp , dtype ):
42934294 dpnp_dtype = get_all_dtypes (no_none = True )[- 1 ]
4294- a1 = dpnp .arange (5 * 4 , dtype = dpnp_dtype ).reshape (5 , 4 )
4295- a2 = dpnp .arange (7 * 4 , dtype = dpnp_dtype ).reshape (4 , 7 )
4296- dp_out = dpnp .empty ((5 , 7 ), dtype = dtype )
4295+ a1 = xp .arange (5 * 4 , dtype = dpnp_dtype ).reshape (5 , 4 )
4296+ a2 = xp .arange (7 * 4 , dtype = dpnp_dtype ).reshape (4 , 7 )
4297+ dp_out = xp .empty ((5 , 7 ), dtype = dtype )
42974298
42984299 with pytest .raises (TypeError ):
4299- dpnp .matmul (a1 , a2 , out = dp_out )
4300+ xp .matmul (a1 , a2 , out = dp_out )
43004301
43014302 def test_exe_q (self ):
43024303 x1 = dpnp .ones ((5 , 4 ), sycl_queue = dpctl .SyclQueue ())
@@ -4310,13 +4311,14 @@ def test_exe_q(self):
43104311 with pytest .raises (ExecutionPlacementError ):
43114312 dpnp .matmul (x1 , x2 , out = out )
43124313
4313- def test_matmul_casting (self ):
4314- a1 = dpnp .arange (2 * 4 , dtype = dpnp .float32 ).reshape (2 , 4 )
4315- a2 = dpnp .arange (4 * 3 ).reshape (4 , 3 )
4314+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
4315+ def test_matmul_casting (self , xp ):
4316+ a1 = xp .arange (2 * 4 , dtype = xp .float32 ).reshape (2 , 4 )
4317+ a2 = xp .arange (4 * 3 ).reshape (4 , 3 )
43164318
4317- res = dpnp .empty ((2 , 3 ), dtype = dpnp .int64 )
4319+ res = xp .empty ((2 , 3 ), dtype = xp .int64 )
43184320 with pytest .raises (TypeError ):
4319- dpnp .matmul (a1 , a2 , out = res , casting = "safe" )
4321+ xp .matmul (a1 , a2 , out = res , casting = "safe" )
43204322
43214323 def test_matmul_not_implemented (self ):
43224324 a1 = dpnp .arange (2 * 4 ).reshape (2 , 4 )
@@ -4332,52 +4334,53 @@ def test_matmul_not_implemented(self):
43324334 with pytest .raises (NotImplementedError ):
43334335 dpnp .matmul (a1 , a2 , axis = 2 )
43344336
4335- def test_matmul_axes (self ):
4336- a1 = dpnp .arange (120 ).reshape (2 , 5 , 3 , 4 )
4337- a2 = dpnp .arange (120 ).reshape (4 , 2 , 5 , 3 )
4337+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
4338+ def test_matmul_axes (self , xp ):
4339+ a1 = xp .arange (120 ).reshape (2 , 5 , 3 , 4 )
4340+ a2 = xp .arange (120 ).reshape (4 , 2 , 5 , 3 )
43384341
43394342 # axes must be a list
43404343 axes = ((3 , 1 ), (2 , 0 ), (0 , 1 ))
43414344 with pytest .raises (TypeError ):
4342- dpnp .matmul (a1 , a2 , axes = axes )
4345+ xp .matmul (a1 , a2 , axes = axes )
43434346
43444347 # axes must be be a list of three tuples
43454348 axes = [(3 , 1 ), (2 , 0 )]
43464349 with pytest .raises (ValueError ):
4347- dpnp .matmul (a1 , a2 , axes = axes )
4350+ xp .matmul (a1 , a2 , axes = axes )
43484351
43494352 # axes item should be a tuple
43504353 axes = [(3 , 1 ), (2 , 0 ), [0 , 1 ]]
43514354 with pytest .raises (TypeError ):
4352- dpnp .matmul (a1 , a2 , axes = axes )
4355+ xp .matmul (a1 , a2 , axes = axes )
43534356
43544357 # axes item should be a tuple with 2 elements
43554358 axes = [(3 , 1 ), (2 , 0 ), (0 , 1 , 2 )]
43564359 with pytest .raises (ValueError ):
4357- dpnp .matmul (a1 , a2 , axes = axes )
4360+ xp .matmul (a1 , a2 , axes = axes )
43584361
43594362 # axes must be an integer
43604363 axes = [(3 , 1 ), (2 , 0 ), (0.0 , 1 )]
43614364 with pytest .raises (TypeError ):
4362- dpnp .matmul (a1 , a2 , axes = axes )
4365+ xp .matmul (a1 , a2 , axes = axes )
43634366
43644367 # axes item 2 should be an empty tuple
4365- a = dpnp .arange (3 )
4368+ a = xp .arange (3 )
43664369 axes = [0 , 0 , 0 ]
4367- with pytest .raises (TypeError ):
4368- dpnp .matmul (a , a , axes = axes )
4370+ with pytest .raises (ValueError ):
4371+ xp .matmul (a , a , axes = axes )
43694372
4370- a = dpnp .arange (3 * 4 * 5 ).reshape (3 , 4 , 5 )
4371- b = dpnp .arange (3 )
4373+ a = xp .arange (3 * 4 * 5 ).reshape (3 , 4 , 5 )
4374+ b = xp .arange (3 )
43724375 # list object cannot be interpreted as an integer
43734376 axes = [(1 , 0 ), (0 ), [0 ]]
43744377 with pytest .raises (TypeError ):
4375- dpnp .matmul (a , b , axes = axes )
4378+ xp .matmul (a , b , axes = axes )
43764379
43774380 # axes item should be a tuple with a single element, or an integer
43784381 axes = [(1 , 0 ), (0 ), (0 , 1 )]
43794382 with pytest .raises (ValueError ):
4380- dpnp .matmul (a , b , axes = axes )
4383+ xp .matmul (a , b , axes = axes )
43814384
43824385
43834386def test_elemenwise_nin_nout ():
0 commit comments