Skip to content

Commit 88e4449

Browse files
committed
Enable tests for matmul
1 parent ee43c3b commit 88e4449

File tree

1 file changed

+37
-37
lines changed

1 file changed

+37
-37
lines changed

dpnp/tests/test_usm_type.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -458,45 +458,45 @@ def test_coerced_usm_types_bitwise_op(op, usm_type_x, usm_type_y):
458458
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
459459

460460

461-
# @pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
462-
# @pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
463-
# @pytest.mark.parametrize(
464-
# "shape_pair",
465-
# [
466-
# ((2, 4), (4,)),
467-
# ((4,), (4, 3)),
468-
# ((2, 4), (4, 3)),
469-
# ((2, 0), (0, 3)),
470-
# ((2, 4), (4, 0)),
471-
# ((4, 2, 3), (4, 3, 5)),
472-
# ((4, 2, 3), (4, 3, 1)),
473-
# ((4, 1, 3), (4, 3, 5)),
474-
# ((6, 7, 4, 3), (6, 7, 3, 5)),
475-
# ],
476-
# ids=[
477-
# "((2, 4), (4,))",
478-
# "((4,), (4, 3))",
479-
# "((2, 4), (4, 3))",
480-
# "((2, 0), (0, 3))",
481-
# "((2, 4), (4, 0))",
482-
# "((4, 2, 3), (4, 3, 5))",
483-
# "((4, 2, 3), (4, 3, 1))",
484-
# "((4, 1, 3), (4, 3, 5))",
485-
# "((6, 7, 4, 3), (6, 7, 3, 5))",
486-
# ],
487-
# )
488-
# def test_matmul(usm_type_x, usm_type_y, shape_pair):
489-
# shape1, shape2 = shape_pair
490-
# x = numpy.arange(numpy.prod(shape1)).reshape(shape1)
491-
# y = numpy.arange(numpy.prod(shape2)).reshape(shape2)
461+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
462+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
463+
@pytest.mark.parametrize(
464+
"shape_pair",
465+
[
466+
((2, 4), (4,)),
467+
((4,), (4, 3)),
468+
((2, 4), (4, 3)),
469+
((2, 0), (0, 3)),
470+
((2, 4), (4, 0)),
471+
((4, 2, 3), (4, 3, 5)),
472+
((4, 2, 3), (4, 3, 1)),
473+
((4, 1, 3), (4, 3, 5)),
474+
((6, 7, 4, 3), (6, 7, 3, 5)),
475+
],
476+
ids=[
477+
"((2, 4), (4,))",
478+
"((4,), (4, 3))",
479+
"((2, 4), (4, 3))",
480+
"((2, 0), (0, 3))",
481+
"((2, 4), (4, 0))",
482+
"((4, 2, 3), (4, 3, 5))",
483+
"((4, 2, 3), (4, 3, 1))",
484+
"((4, 1, 3), (4, 3, 5))",
485+
"((6, 7, 4, 3), (6, 7, 3, 5))",
486+
],
487+
)
488+
def test_matmul(usm_type_x, usm_type_y, shape_pair):
489+
shape1, shape2 = shape_pair
490+
x = numpy.arange(numpy.prod(shape1)).reshape(shape1)
491+
y = numpy.arange(numpy.prod(shape2)).reshape(shape2)
492492

493-
# x = dp.array(x, usm_type=usm_type_x)
494-
# y = dp.array(y, usm_type=usm_type_y)
495-
# z = dp.matmul(x, y)
493+
x = dp.array(x, usm_type=usm_type_x)
494+
y = dp.array(y, usm_type=usm_type_y)
495+
z = dp.matmul(x, y)
496496

497-
# assert x.usm_type == usm_type_x
498-
# assert y.usm_type == usm_type_y
499-
# assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
497+
assert x.usm_type == usm_type_x
498+
assert y.usm_type == usm_type_y
499+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
500500

501501

502502
# @pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)

0 commit comments

Comments
 (0)