|
29 | 29 | set_seed() |
30 | 30 |
|
31 | 31 |
|
| 32 | +@pytest.fixture(autouse=True) |
| 33 | +def setup_seed(): |
| 34 | + """Set seed before each test function.""" |
| 35 | + set_seed() |
| 36 | + |
| 37 | + |
32 | 38 | @pytest.mark.parametrize( |
33 | 39 | ("config", "gemm_forward", "atol", "rtol"), |
34 | 40 | [ |
@@ -257,9 +263,9 @@ def forward_loop(model, run_backward=False): |
257 | 263 |
|
258 | 264 | # The way the weights and inputs are compressed might be different. |
259 | 265 | # E.g. we may use torch.compile in the gemms. |
260 | | - assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 3) |
261 | | - assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 3) |
| 266 | + assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2) |
| 267 | + assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2) |
262 | 268 | assert torch.allclose( |
263 | | - output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 3 |
| 269 | + output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2 |
264 | 270 | ) |
265 | | - assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 3) |
| 271 | + assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2) |
0 commit comments