Skip to content

Commit 69ede62

Browse files
Added a test for matmul of medium size matrix
1 parent 3fb0608 commit 69ede62

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,41 @@ def test_matmul_inplace_same_tensors():
571571
assert dpt.all(ar2 == dpt.full(sh, n, dtype=ar2.dtype))
572572

573573

574+
@pytest.fixture
575+
def random_matrix():
576+
rs = np.random.RandomState(seed=123456)
577+
m_np = rs.randint(low=0, high=6, size=(400, 400))
578+
return m_np
579+
580+
581+
@pytest.mark.parametrize("dtype", _numeric_types)
582+
def test_matmul_largish(dtype, random_matrix):
583+
q = get_queue_or_skip()
584+
skip_if_dtype_not_supported(dtype, q)
585+
586+
m_np = random_matrix.astype(dtype)
587+
x_np = np.matmul(m_np.T, m_np)
588+
589+
m = dpt.asarray(m_np)
590+
mT = dpt.asarray(m.mT, copy=True, order="C")
591+
x1 = dpt.matmul(m.mT, m)
592+
x2 = dpt.matmul(mT, m)
593+
594+
assert dpt.allclose(x1, x2)
595+
assert dpt.allclose(x1, dpt.asarray(x_np))
596+
597+
m_np = m_np[:-1, :-1]
598+
x_np = np.matmul(m_np.T, m_np)
599+
600+
m = m[:-1, :-1]
601+
mT = dpt.asarray(m.mT, copy=True, order="C")
602+
x1 = dpt.matmul(m.mT, m)
603+
x2 = dpt.matmul(mT, m)
604+
605+
assert dpt.allclose(x1, x2)
606+
assert dpt.allclose(x1, dpt.asarray(x_np))
607+
608+
574609
@pytest.mark.parametrize("dtype", _numeric_types)
575610
def test_tensordot_outer(dtype):
576611
q = get_queue_or_skip()

0 commit comments

Comments
 (0)