File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 37
37
from torch .testing import FileCheck
38
38
from torchao .quantization .granularity import PerAxis , PerGroup
39
39
from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
40
+ from torchao .quantization .utils import compute_error
40
41
41
42
42
43
class TestQuantFusionPass (unittest .TestCase ):
@@ -470,7 +471,8 @@ def _test_embedding_torchao(
470
471
471
472
# Compare numerics
472
473
actual_outputs = m .exported_program ().module ()(* example_inputs )
473
- self .assertTrue (torch .allclose (expected_outputs , actual_outputs ))
474
+ sqnr = compute_error (expected_outputs , actual_outputs )
475
+ self .assertTrue (sqnr >= 50 , f"Got sqnr { sqnr } " )
474
476
475
477
# Can lower to executorch
476
478
exec_prog = m .to_executorch () # noqa
@@ -488,7 +490,8 @@ def _test_embedding_torchao(
488
490
)
489
491
490
492
actual_outputs2 = m_copy .exported_program ().module ()(* example_inputs )
491
- self .assertTrue (torch .allclose (expected_outputs , actual_outputs2 ))
493
+ sqnr = compute_error (expected_outputs , actual_outputs2 )
494
+ self .assertTrue (sqnr >= 50 , f"Got sqnr { sqnr } " )
492
495
493
496
# Can lower to executorch
494
497
exec_prog2 = m_copy .to_executorch () # noqa
You can’t perform that action at this time.
0 commit comments