@@ -629,50 +629,50 @@ def test_bitwise_op_2in(op, device):
629
629
assert_sycl_queue_equal (zy .sycl_queue , y .sycl_queue )
630
630
631
631
632
- @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
633
- @pytest .mark .parametrize ("dtype" , [dpnp .int32 , dpnp .float32 ])
634
- @pytest .mark .parametrize (
635
- "shape1, shape2" ,
636
- [
637
- ((2 , 4 ), (4 ,)),
638
- ((4 ,), (4 , 3 )),
639
- ((2 , 4 ), (4 , 3 )),
640
- ((2 , 0 ), (0 , 3 )),
641
- ((2 , 4 ), (4 , 0 )),
642
- ((4 , 2 , 3 ), (4 , 3 , 5 )),
643
- ((4 , 2 , 3 ), (4 , 3 , 1 )),
644
- ((4 , 1 , 3 ), (4 , 3 , 5 )),
645
- ((6 , 7 , 4 , 3 ), (6 , 7 , 3 , 5 )),
646
- ],
647
- ids = [
648
- "((2, 4), (4,))" ,
649
- "((4,), (4, 3))" ,
650
- "((2, 4), (4, 3))" ,
651
- "((2, 0), (0, 3))" ,
652
- "((2, 4), (4, 0))" ,
653
- "((4, 2, 3), (4, 3, 5))" ,
654
- "((4, 2, 3), (4, 3, 1))" ,
655
- "((4, 1, 3), (4, 3, 5))" ,
656
- "((6, 7, 4, 3), (6, 7, 3, 5))" ,
657
- ],
658
- )
659
- def test_matmul (device , dtype , shape1 , shape2 ):
660
- # int32 checks dpctl implementation and float32 checks oneMKL
661
- a = dpnp .arange (numpy .prod (shape1 ), dtype = dtype , device = device )
662
- b = dpnp .arange (numpy .prod (shape2 ), dtype = dtype , device = device )
663
- a , b = a .reshape (shape1 ), b .reshape (shape2 )
664
- result = dpnp .matmul (a , b )
665
-
666
- result_queue = result .sycl_queue
667
- assert_sycl_queue_equal (result_queue , a .sycl_queue )
668
- assert_sycl_queue_equal (result_queue , b .sycl_queue )
632
+ class TestMatmul :
633
+ @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
634
+ @pytest .mark .parametrize ("dtype" , [dpnp .int32 , dpnp .float32 ])
635
+ @pytest .mark .parametrize (
636
+ "shape1, shape2" ,
637
+ [
638
+ ((2 , 4 ), (4 ,)),
639
+ ((4 ,), (4 , 3 )),
640
+ ((2 , 4 ), (4 , 3 )),
641
+ ((2 , 0 ), (0 , 3 )),
642
+ ((2 , 4 ), (4 , 0 )),
643
+ ((4 , 2 , 3 ), (4 , 3 , 5 )),
644
+ ((4 , 2 , 3 ), (4 , 3 , 1 )),
645
+ ((4 , 1 , 3 ), (4 , 3 , 5 )),
646
+ ((6 , 7 , 4 , 3 ), (6 , 7 , 3 , 5 )),
647
+ ],
648
+ ids = [
649
+ "((2, 4), (4,))" ,
650
+ "((4,), (4, 3))" ,
651
+ "((2, 4), (4, 3))" ,
652
+ "((2, 0), (0, 3))" ,
653
+ "((2, 4), (4, 0))" ,
654
+ "((4, 2, 3), (4, 3, 5))" ,
655
+ "((4, 2, 3), (4, 3, 1))" ,
656
+ "((4, 1, 3), (4, 3, 5))" ,
657
+ "((6, 7, 4, 3), (6, 7, 3, 5))" ,
658
+ ],
659
+ )
660
+ def test_matmul (self , device , dtype , shape1 , shape2 ):
661
+ # int32 checks dpctl implementation and float32 checks oneMKL
662
+ a = dpnp .arange (numpy .prod (shape1 ), dtype = dtype , device = device )
663
+ b = dpnp .arange (numpy .prod (shape2 ), dtype = dtype , device = device )
664
+ a , b = a .reshape (shape1 ), b .reshape (shape2 )
665
+ result = dpnp .matmul (a , b )
669
666
667
+ result_queue = result .sycl_queue
668
+ assert_sycl_queue_equal (result_queue , a .sycl_queue )
669
+ assert_sycl_queue_equal (result_queue , b .sycl_queue )
670
670
671
- @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
672
- def test_matmul_syrk (device ):
673
- a = dpnp .arange (20 , dtype = dpnp .float32 , device = device ).reshape (4 , 5 )
674
- result = dpnp .matmul (a , a .mT )
675
- assert_sycl_queue_equal (result .sycl_queue , a .sycl_queue )
671
+ @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
672
+ def test_matmul_syrk (self , device ):
673
+ a = dpnp .arange (20 , dtype = dpnp .float32 , device = device ).reshape (4 , 5 )
674
+ result = dpnp .matmul (a , a .mT )
675
+ assert_sycl_queue_equal (result .sycl_queue , a .sycl_queue )
676
676
677
677
678
678
@pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
0 commit comments