1212from executorch import exir
1313from executorch .exir import EdgeCompileConfig , to_edge
1414from executorch .exir .passes .constant_prop_pass import constant_prop_pass
15- from executorch .exir .passes .quant_fusion_pass import QuantFusionPass
15+ from executorch .exir .passes .quant_fusion_pass import QuantFusionPass , quant_fusion_and_const_prop_pass
1616from executorch .exir .tests .common import register_additional_test_aten_ops
1717from torch .ao .quantization import ( # @manual
1818 float_qparams_weight_only_qconfig ,
3333from torch .testing import FileCheck
3434from torchao .quantization .granularity import PerAxis , PerGroup
3535from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
36-
36+ import copy
3737
3838class TestQuantFusionPass (unittest .TestCase ):
3939 @classmethod
@@ -419,6 +419,7 @@ def _test_embedding_torchao(
419419 m = to_edge (
420420 export (model , example_inputs , strict = True ), compile_config = compile_config
421421 )
422+ m_copy = copy .deepcopy (m )
422423
423424 # Before pass, we see torchao dequantize and embedding ops
424425 FileCheck ().check_count (
@@ -437,13 +438,9 @@ def _test_embedding_torchao(
437438
438439 # After pass, we see packing op and quantized embedding op, but no torchao dequantize op
439440 FileCheck ().check_count (
440- "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" ,
441- 1 if bit_width < 8 else 0 ,
442- exactly = True ,
441+ "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" , 1 if bit_width < 8 else 0 , exactly = True
443442 ).check_count (
444- f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " ,
445- 1 ,
446- exactly = True ,
443+ f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " , 1 , exactly = True ,
447444 ).check_not (
448445 "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
449446 ).run (
@@ -454,9 +451,7 @@ def _test_embedding_torchao(
454451
455452 # After constant prop, we see quantized embedding op, but no packing op
456453 FileCheck ().check_count (
457- f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " ,
458- 1 ,
459- exactly = True ,
454+ f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " , 1 , exactly = True ,
460455 ).check_not (
461456 "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" ,
462457 ).run (
@@ -468,4 +463,21 @@ def _test_embedding_torchao(
468463 self .assertTrue (torch .allclose (expected_outputs , actual_outputs ))
469464
470465 # Can lower to executorch
471- exec_prog = m .to_executorch () # noqa: F841
466+ exec_prog = m .to_executorch () # noqa
467+
468+
469+ # Alternative flow 2: call quant_fusion_and_const_prop_pass on the exported program
470+ quant_fusion_and_const_prop_pass (m_copy .exported_program ())
471+ FileCheck ().check_count (
472+ f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " , 1 , exactly = True ,
473+ ).check_not (
474+ "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" ,
475+ ).run (
476+ m_copy .exported_program ().graph_module .code
477+ )
478+
479+ actual_outputs2 = m_copy .exported_program ().module ()(* example_inputs )
480+ self .assertTrue (torch .allclose (expected_outputs , actual_outputs2 ))
481+
482+ # Can lower to executorch
483+ exec_prog2 = m_copy .to_executorch () # noqa
0 commit comments