Skip to content

Commit a748476

Browse files
fix: triton unit test default was changed in prev PR, causing slightly increase in rel errs, change settings in tests to match what we had before
Signed-off-by: cliu-us <[email protected]>
1 parent 766c805 commit a748476

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/triton_kernels/test_triton_mm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ def test_triton_matmul_fp(mkn, dtype_to_test):
6969
.to("cuda")
7070
.to(torch.float)
7171
)
72-
tl_output_no_trun = tl_matmul(a, b).to(torch.float)
73-
tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float)
72+
tl_output_no_trun = tl_matmul(a, b, truncate_then_accumulate=False).to(torch.float)
73+
tl_output_trun_8b = tl_matmul(
74+
a, b, chunk_trun_bits=8, truncate_then_accumulate=False
75+
).to(torch.float)
7476

7577
diff_no_trun = torch_output - tl_output_no_trun
7678
diff_trun_8b = torch_output - tl_output_trun_8b

0 commit comments

Comments
 (0)