Skip to content

Commit 594fef6

Browse files
authored
Fix flakey quant fusion test (#14547)
As titled, we fix the flakey test by using SQNR >= 50 rather than torch.allclose
1 parent 9283b4e commit 594fef6

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

exir/tests/test_quant_fusion_pass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from torch.testing import FileCheck
3838
from torchao.quantization.granularity import PerAxis, PerGroup
3939
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
40+
from torchao.quantization.utils import compute_error
4041

4142

4243
class TestQuantFusionPass(unittest.TestCase):
@@ -470,7 +471,8 @@ def _test_embedding_torchao(
470471

471472
# Compare numerics
472473
actual_outputs = m.exported_program().module()(*example_inputs)
473-
self.assertTrue(torch.allclose(expected_outputs, actual_outputs))
474+
sqnr = compute_error(expected_outputs, actual_outputs)
475+
self.assertTrue(sqnr >= 50, f"Got sqnr {sqnr}")
474476

475477
# Can lower to executorch
476478
exec_prog = m.to_executorch() # noqa
@@ -488,7 +490,8 @@ def _test_embedding_torchao(
488490
)
489491

490492
actual_outputs2 = m_copy.exported_program().module()(*example_inputs)
491-
self.assertTrue(torch.allclose(expected_outputs, actual_outputs2))
493+
sqnr = compute_error(expected_outputs, actual_outputs2)
494+
self.assertTrue(sqnr >= 50, f"Got sqnr {sqnr}")
492495

493496
# Can lower to executorch
494497
exec_prog2 = m_copy.to_executorch() # noqa

0 commit comments

Comments
 (0)