File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 3737from torch .testing import FileCheck
3838from torchao .quantization .granularity import PerAxis , PerGroup
3939from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
40+ from torchao .quantization .utils import compute_error
4041
4142
4243class TestQuantFusionPass (unittest .TestCase ):
@@ -470,7 +471,8 @@ def _test_embedding_torchao(
470471
471472 # Compare numerics
472473 actual_outputs = m .exported_program ().module ()(* example_inputs )
473- self .assertTrue (torch .allclose (expected_outputs , actual_outputs ))
474+ sqnr = compute_error (expected_outputs , actual_outputs )
475+ self .assertTrue (sqnr >= 50 , f"Got sqnr { sqnr } " )
474476
475477 # Can lower to executorch
476478 exec_prog = m .to_executorch () # noqa
@@ -488,7 +490,8 @@ def _test_embedding_torchao(
488490 )
489491
490492 actual_outputs2 = m_copy .exported_program ().module ()(* example_inputs )
491- self .assertTrue (torch .allclose (expected_outputs , actual_outputs2 ))
493+ sqnr = compute_error (expected_outputs , actual_outputs2 )
494+ self .assertTrue (sqnr >= 50 , f"Got sqnr { sqnr } " )
492495
493496 # Can lower to executorch
494497 exec_prog2 = m_copy .to_executorch () # noqa
You can’t perform that action at this time.
0 commit comments