@@ -458,45 +458,45 @@ def test_coerced_usm_types_bitwise_op(op, usm_type_x, usm_type_y):
458458 assert z .usm_type == du .get_coerced_usm_type ([usm_type_x , usm_type_y ])
459459
460460
461- # @pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
462- # @pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
463- # @pytest.mark.parametrize(
464- # "shape_pair",
465- # [
466- # ((2, 4), (4,)),
467- # ((4,), (4, 3)),
468- # ((2, 4), (4, 3)),
469- # ((2, 0), (0, 3)),
470- # ((2, 4), (4, 0)),
471- # ((4, 2, 3), (4, 3, 5)),
472- # ((4, 2, 3), (4, 3, 1)),
473- # ((4, 1, 3), (4, 3, 5)),
474- # ((6, 7, 4, 3), (6, 7, 3, 5)),
475- # ],
476- # ids=[
477- # "((2, 4), (4,))",
478- # "((4,), (4, 3))",
479- # "((2, 4), (4, 3))",
480- # "((2, 0), (0, 3))",
481- # "((2, 4), (4, 0))",
482- # "((4, 2, 3), (4, 3, 5))",
483- # "((4, 2, 3), (4, 3, 1))",
484- # "((4, 1, 3), (4, 3, 5))",
485- # "((6, 7, 4, 3), (6, 7, 3, 5))",
486- # ],
487- # )
488- # def test_matmul(usm_type_x, usm_type_y, shape_pair):
489- # shape1, shape2 = shape_pair
490- # x = numpy.arange(numpy.prod(shape1)).reshape(shape1)
491- # y = numpy.arange(numpy.prod(shape2)).reshape(shape2)
461+ @pytest .mark .parametrize ("usm_type_x" , list_of_usm_types , ids = list_of_usm_types )
462+ @pytest .mark .parametrize ("usm_type_y" , list_of_usm_types , ids = list_of_usm_types )
463+ @pytest .mark .parametrize (
464+ "shape_pair" ,
465+ [
466+ ((2 , 4 ), (4 ,)),
467+ ((4 ,), (4 , 3 )),
468+ ((2 , 4 ), (4 , 3 )),
469+ ((2 , 0 ), (0 , 3 )),
470+ ((2 , 4 ), (4 , 0 )),
471+ ((4 , 2 , 3 ), (4 , 3 , 5 )),
472+ ((4 , 2 , 3 ), (4 , 3 , 1 )),
473+ ((4 , 1 , 3 ), (4 , 3 , 5 )),
474+ ((6 , 7 , 4 , 3 ), (6 , 7 , 3 , 5 )),
475+ ],
476+ ids = [
477+ "((2, 4), (4,))" ,
478+ "((4,), (4, 3))" ,
479+ "((2, 4), (4, 3))" ,
480+ "((2, 0), (0, 3))" ,
481+ "((2, 4), (4, 0))" ,
482+ "((4, 2, 3), (4, 3, 5))" ,
483+ "((4, 2, 3), (4, 3, 1))" ,
484+ "((4, 1, 3), (4, 3, 5))" ,
485+ "((6, 7, 4, 3), (6, 7, 3, 5))" ,
486+ ],
487+ )
488+ def test_matmul (usm_type_x , usm_type_y , shape_pair ):
489+ shape1 , shape2 = shape_pair
490+ x = numpy .arange (numpy .prod (shape1 )).reshape (shape1 )
491+ y = numpy .arange (numpy .prod (shape2 )).reshape (shape2 )
492492
493- # x = dp.array(x, usm_type=usm_type_x)
494- # y = dp.array(y, usm_type=usm_type_y)
495- # z = dp.matmul(x, y)
493+ x = dp .array (x , usm_type = usm_type_x )
494+ y = dp .array (y , usm_type = usm_type_y )
495+ z = dp .matmul (x , y )
496496
497- # assert x.usm_type == usm_type_x
498- # assert y.usm_type == usm_type_y
499- # assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
497+ assert x .usm_type == usm_type_x
498+ assert y .usm_type == usm_type_y
499+ assert z .usm_type == du .get_coerced_usm_type ([usm_type_x , usm_type_y ])
500500
501501
502502# @pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
0 commit comments