@@ -629,50 +629,50 @@ def test_bitwise_op_2in(op, device):
629629 assert_sycl_queue_equal (zy .sycl_queue , y .sycl_queue )
630630
631631
632- @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
633- @pytest .mark .parametrize ("dtype" , [dpnp .int32 , dpnp .float32 ])
634- @pytest .mark .parametrize (
635- "shape1, shape2" ,
636- [
637- ((2 , 4 ), (4 ,)),
638- ((4 ,), (4 , 3 )),
639- ((2 , 4 ), (4 , 3 )),
640- ((2 , 0 ), (0 , 3 )),
641- ((2 , 4 ), (4 , 0 )),
642- ((4 , 2 , 3 ), (4 , 3 , 5 )),
643- ((4 , 2 , 3 ), (4 , 3 , 1 )),
644- ((4 , 1 , 3 ), (4 , 3 , 5 )),
645- ((6 , 7 , 4 , 3 ), (6 , 7 , 3 , 5 )),
646- ],
647- ids = [
648- "((2, 4), (4,))" ,
649- "((4,), (4, 3))" ,
650- "((2, 4), (4, 3))" ,
651- "((2, 0), (0, 3))" ,
652- "((2, 4), (4, 0))" ,
653- "((4, 2, 3), (4, 3, 5))" ,
654- "((4, 2, 3), (4, 3, 1))" ,
655- "((4, 1, 3), (4, 3, 5))" ,
656- "((6, 7, 4, 3), (6, 7, 3, 5))" ,
657- ],
658- )
659- def test_matmul (device , dtype , shape1 , shape2 ):
660- # int32 checks dpctl implementation and float32 checks oneMKL
661- a = dpnp .arange (numpy .prod (shape1 ), dtype = dtype , device = device )
662- b = dpnp .arange (numpy .prod (shape2 ), dtype = dtype , device = device )
663- a , b = a .reshape (shape1 ), b .reshape (shape2 )
664- result = dpnp .matmul (a , b )
665-
666- result_queue = result .sycl_queue
667- assert_sycl_queue_equal (result_queue , a .sycl_queue )
668- assert_sycl_queue_equal (result_queue , b .sycl_queue )
632+ class TestMatmul :
633+ @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
634+ @pytest .mark .parametrize ("dtype" , [dpnp .int32 , dpnp .float32 ])
635+ @pytest .mark .parametrize (
636+ "shape1, shape2" ,
637+ [
638+ ((2 , 4 ), (4 ,)),
639+ ((4 ,), (4 , 3 )),
640+ ((2 , 4 ), (4 , 3 )),
641+ ((2 , 0 ), (0 , 3 )),
642+ ((2 , 4 ), (4 , 0 )),
643+ ((4 , 2 , 3 ), (4 , 3 , 5 )),
644+ ((4 , 2 , 3 ), (4 , 3 , 1 )),
645+ ((4 , 1 , 3 ), (4 , 3 , 5 )),
646+ ((6 , 7 , 4 , 3 ), (6 , 7 , 3 , 5 )),
647+ ],
648+ ids = [
649+ "((2, 4), (4,))" ,
650+ "((4,), (4, 3))" ,
651+ "((2, 4), (4, 3))" ,
652+ "((2, 0), (0, 3))" ,
653+ "((2, 4), (4, 0))" ,
654+ "((4, 2, 3), (4, 3, 5))" ,
655+ "((4, 2, 3), (4, 3, 1))" ,
656+ "((4, 1, 3), (4, 3, 5))" ,
657+ "((6, 7, 4, 3), (6, 7, 3, 5))" ,
658+ ],
659+ )
660+ def test_matmul (self , device , dtype , shape1 , shape2 ):
661+ # int32 checks dpctl implementation and float32 checks oneMKL
662+ a = dpnp .arange (numpy .prod (shape1 ), dtype = dtype , device = device )
663+ b = dpnp .arange (numpy .prod (shape2 ), dtype = dtype , device = device )
664+ a , b = a .reshape (shape1 ), b .reshape (shape2 )
665+ result = dpnp .matmul (a , b )
669666
667+ result_queue = result .sycl_queue
668+ assert_sycl_queue_equal (result_queue , a .sycl_queue )
669+ assert_sycl_queue_equal (result_queue , b .sycl_queue )
670670
671- @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
672- def test_matmul_syrk (device ):
673- a = dpnp .arange (20 , dtype = dpnp .float32 , device = device ).reshape (4 , 5 )
674- result = dpnp .matmul (a , a .mT )
675- assert_sycl_queue_equal (result .sycl_queue , a .sycl_queue )
671+ @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
672+ def test_matmul_syrk (self , device ):
673+ a = dpnp .arange (20 , dtype = dpnp .float32 , device = device ).reshape (4 , 5 )
674+ result = dpnp .matmul (a , a .mT )
675+ assert_sycl_queue_equal (result .sycl_queue , a .sycl_queue )
676676
677677
678678@pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
0 commit comments