Skip to content

Commit dcad270

Browse files
[AMD] Fix test_dot_multidim floating-point comparision (#8780)
This PR fixes test failures in `test_dot_multidim` that occur on AMD RDNA3 GPUs due to overly strict floating-point comparisons. The test currently uses `torch.equal()` which requires exact bit-for-bit equality, but different GPU architectures and drivers can produce slightly different results at the machine epsilon level while still being mathematically correct. This change replaces the exact equality check with `torch.allclose()` using appropriate tolerances. Triton --> tensor([[[[ -3.0000, -44.0000, 8.0000, ..., 1.0000, -28.0000, -12.0000], [ -6.0000, 46.0000, 10.0000, ..., -13.0000, 35.0000, 50.0000], [ 83.0000, 5.0000, 59.0000, ..., -32.0000, -19.0000, 96.0000], ..., Torch ---> tensor([[[[ -3., -44., 8., ..., 1., -28., -12.], [ -6., 46., 10., ..., -13., 35., 50.], [ 83., 5., 59., ..., -32., -19., 96.], ..., The printed values appear identical, but `torch.equal()` detects differences at the least significant bits level.
1 parent 597b17e commit dcad270

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/test/unit/language/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6693,4 +6693,4 @@ def kernel(X, Y, Z, RANK: tl.constexpr, TRANS_A: tl.constexpr, TRANS_B: tl.const
66936693

66946694
d = a.to(torch.float32) @ b.to(torch.float32)
66956695

6696-
assert torch.equal(c, d)
6696+
assert torch.allclose(c, d, rtol=1e-3, atol=1e-2)

0 commit comments

Comments
 (0)