diff --git a/tests/gpu/torch/quantization/backends/test_gemm_common.py b/tests/gpu/torch/quantization/backends/test_gemm_common.py index e58137d99..c3b20b3b2 100644 --- a/tests/gpu/torch/quantization/backends/test_gemm_common.py +++ b/tests/gpu/torch/quantization/backends/test_gemm_common.py @@ -29,6 +29,12 @@ set_seed() +@pytest.fixture(autouse=True) +def setup_seed(): + """Set seed before each test function.""" + set_seed() + + @pytest.mark.parametrize( ("config", "gemm_forward", "atol", "rtol"), [ @@ -257,9 +263,9 @@ def forward_loop(model, run_backward=False): # The way the compression of the weights and inputs might be different. # E.g. we may use torch.compile in the gemms. - assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 3) - assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 3) + assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2) + assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2) assert torch.allclose( - output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 3 + output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2 ) - assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 3) + assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2)