@@ -4168,6 +4168,65 @@ def test_matmul_with_offsets(self, sh1, sh2):
41684168 assert_array_equal (result , expected )
41694169
41704170
4171+ class TestMatmulInplace :
4172+ ALL_DTYPES = get_all_dtypes (no_none = True )
4173+ DTYPES = {}
4174+ for i in ALL_DTYPES :
4175+ for j in ALL_DTYPES :
4176+ if numpy .can_cast (j , i ):
4177+ DTYPES [f"{ i } -{ j } " ] = (i , j )
4178+
4179+ @pytest .mark .parametrize ("dtype1, dtype2" , DTYPES .values ())
4180+ def test_basic (self , dtype1 , dtype2 ):
4181+ a = numpy .arange (10 ).reshape (5 , 2 ).astype (dtype1 )
4182+ b = numpy .ones ((2 , 2 ), dtype = dtype2 )
4183+ ia , ib = dpnp .array (a ), dpnp .array (b )
4184+ ia_id = id (ia )
4185+
4186+ a @= b
4187+ ia @= ib
4188+ assert id (ia ) == ia_id
4189+ assert_dtype_allclose (ia , a )
4190+
4191+ @pytest .mark .parametrize (
4192+ "a_sh, b_sh" ,
4193+ [
4194+ pytest .param ((10 ** 5 , 10 ), (10 , 10 ), id = "2d_large" ),
4195+ pytest .param ((10 ** 4 , 10 , 10 ), (1 , 10 , 10 ), id = "3d_large" ),
4196+ pytest .param ((3 ,), (3 ,), id = "1d" ),
4197+ pytest .param ((3 , 3 ), (3 ,), id = "2d_1d" ),
4198+ pytest .param ((3 ,), (3 , 3 ), id = "1d_2d" ),
4199+ pytest .param ((3 , 3 ), (3 , 1 ), id = "2d_broadcast" ),
4200+ pytest .param ((1 , 3 ), (3 , 3 ), id = "2d_broadcast_reverse" ),
4201+ pytest .param ((3 , 3 , 3 ), (1 , 3 , 1 ), id = "3d_broadcast1" ),
4202+ pytest .param ((3 , 3 , 3 ), (1 , 3 , 3 ), id = "3d_broadcast2" ),
4203+ pytest .param ((3 , 3 , 3 ), (3 , 3 , 1 ), id = "3d_broadcast3" ),
4204+ pytest .param ((1 , 3 , 3 ), (3 , 3 , 3 ), id = "3d_broadcast_reverse1" ),
4205+ pytest .param ((3 , 1 , 3 ), (3 , 3 , 3 ), id = "3d_broadcast_reverse2" ),
4206+ pytest .param ((1 , 1 , 3 ), (3 , 3 , 3 ), id = "3d_broadcast_reverse3" ),
4207+ ],
4208+ )
4209+ def test_shapes (self , a_sh , b_sh ):
4210+ a_sz , b_sz = numpy .prod (a_sh ), numpy .prod (b_sh )
4211+ a = numpy .arange (a_sz ).reshape (a_sh ).astype (numpy .float64 )
4212+ b = numpy .arange (b_sz ).reshape (b_sh )
4213+
4214+ ia , ib = dpnp .array (a ), dpnp .array (b )
4215+ ia_id = id (ia )
4216+
4217+ expected = a @ b
4218+ if expected .shape != a_sh :
4219+ with pytest .raises (ValueError ):
4220+ a @= b
4221+
4222+ with pytest .raises (ValueError ):
4223+ ia @= ib
4224+ else :
4225+ ia @= ib
4226+ assert id (ia ) == ia_id
4227+ assert_dtype_allclose (ia , expected )
4228+
4229+
41714230class TestMatmulInvalidCases :
41724231 @pytest .mark .parametrize (
41734232 "shape_pair" ,
0 commit comments