Skip to content

Commit 74df338

Browse files
Use tolerance of 32 epsilon for gemm result comparisons
1 parent d7fc376 commit 74df338

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,12 @@ def test_matmul_largish(dtype, random_matrix):
591591
x1 = dpt.matmul(m.mT, m)
592592
x2 = dpt.matmul(mT, m)
593593

594-
assert dpt.allclose(x1, x2)
595-
assert dpt.allclose(x1, dpt.asarray(x_np))
594+
tol = 0
595+
if dpt.isdtype(x2.dtype, ("real floating", "complex floating")):
596+
tol = 32 * dpt.finfo(x2.dtype).eps
597+
598+
assert dpt.allclose(x1, x2, atol=tol, rtol=tol)
599+
assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol)
596600

597601
m_np = m_np[:-1, :-1]
598602
x_np = np.matmul(m_np.T, m_np)
@@ -602,8 +606,8 @@ def test_matmul_largish(dtype, random_matrix):
602606
x1 = dpt.matmul(m.mT, m)
603607
x2 = dpt.matmul(mT, m)
604608

605-
assert dpt.allclose(x1, x2)
606-
assert dpt.allclose(x1, dpt.asarray(x_np))
609+
assert dpt.allclose(x1, x2, atol=tol, rtol=tol)
610+
assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol)
607611

608612

609613
@pytest.mark.parametrize("dtype", _numeric_types)

0 commit comments

Comments
 (0)