From e500c752f5d98ad5d101d97c55f1ae9e537caf3b Mon Sep 17 00:00:00 2001 From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> Date: Thu, 4 Sep 2025 08:17:06 -0700 Subject: [PATCH] Fix numerical stability of test_gemm_common.py Relax tensor compare threshold Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> --- .../quantization/backends/test_gemm_common.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/gpu/torch/quantization/backends/test_gemm_common.py b/tests/gpu/torch/quantization/backends/test_gemm_common.py index e58137d99..c3b20b3b2 100644 --- a/tests/gpu/torch/quantization/backends/test_gemm_common.py +++ b/tests/gpu/torch/quantization/backends/test_gemm_common.py @@ -29,6 +29,12 @@ set_seed() +@pytest.fixture(autouse=True) +def setup_seed(): + """Set seed before each test function.""" + set_seed() + + @pytest.mark.parametrize( ("config", "gemm_forward", "atol", "rtol"), [ @@ -257,9 +263,9 @@ def forward_loop(model, run_backward=False): # The way the compression of the weights and inputs might be different. # E.g. we may use torch.compile in the gemms. - assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 3) - assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 3) + assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2) + assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2) assert torch.allclose( - output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 3 + output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2 ) - assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 3) + assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2)