Skip to content

Commit 2b52759

Browse files
authored
Fix numerical stability of test_gemm_common.py (#283)
Signed-off-by: Chenjie Luo <[email protected]>
1 parent 76fb12d commit 2b52759

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/gpu/torch/quantization/backends/test_gemm_common.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
set_seed()
3030

3131

32+
@pytest.fixture(autouse=True)
33+
def setup_seed():
34+
"""Set seed before each test function."""
35+
set_seed()
36+
37+
3238
@pytest.mark.parametrize(
3339
("config", "gemm_forward", "atol", "rtol"),
3440
[
@@ -257,9 +263,9 @@ def forward_loop(model, run_backward=False):
257263

258264
# The way the compression of the weights and inputs might be different.
259265
# E.g. we may use torch.compile in the gemms.
260-
assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 3)
261-
assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 3)
266+
assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2)
267+
assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2)
262268
assert torch.allclose(
263-
output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 3
269+
output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2
264270
)
265-
assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 3)
271+
assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2)

0 commit comments

Comments
 (0)