|
29 | 29 | set_seed()
|
30 | 30 |
|
31 | 31 |
|
| 32 | +@pytest.fixture(autouse=True) |
| 33 | +def setup_seed(): |
| 34 | + """Set seed before each test function.""" |
| 35 | + set_seed() |
| 36 | + |
| 37 | + |
32 | 38 | @pytest.mark.parametrize(
|
33 | 39 | ("config", "gemm_forward", "atol", "rtol"),
|
34 | 40 | [
|
@@ -257,9 +263,9 @@ def forward_loop(model, run_backward=False):
|
257 | 263 |
|
258 | 264 | # The way the compression of the weights and inputs might be different.
|
259 | 265 | # E.g. we may use torch.compile in the gemms.
|
260 |
| - assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 3) |
261 |
| - assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 3) |
| 266 | + assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2) |
| 267 | + assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2) |
262 | 268 | assert torch.allclose(
|
263 |
| - output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 3 |
| 269 | + output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2 |
264 | 270 | )
|
265 |
| - assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 3) |
| 271 | + assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2) |
0 commit comments