66
77# pyre-strict
88
9+ import copy
910import unittest
1011
1112import torch
1213from executorch import exir
1314from executorch .exir import EdgeCompileConfig , to_edge
1415from executorch .exir .passes .constant_prop_pass import constant_prop_pass
15- from executorch .exir .passes .quant_fusion_pass import QuantFusionPass , quant_fusion_and_const_prop_pass
16+ from executorch .exir .passes .quant_fusion_pass import (
17+ quant_fusion_and_const_prop_pass ,
18+ QuantFusionPass ,
19+ )
1620from executorch .exir .tests .common import register_additional_test_aten_ops
1721from torch .ao .quantization import ( # @manual
1822 float_qparams_weight_only_qconfig ,
3337from torch .testing import FileCheck
3438from torchao .quantization .granularity import PerAxis , PerGroup
3539from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
36- import copy
40+
3741
3842class TestQuantFusionPass (unittest .TestCase ):
3943 @classmethod
@@ -438,9 +442,13 @@ def _test_embedding_torchao(
438442
439443 # After pass, we see packing op and quantized embedding op, but no torchao dequantize op
440444 FileCheck ().check_count (
441- "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" , 1 if bit_width < 8 else 0 , exactly = True
445+ "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" ,
446+ 1 if bit_width < 8 else 0 ,
447+ exactly = True ,
442448 ).check_count (
443- f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " , 1 , exactly = True ,
449+ f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " ,
450+ 1 ,
451+ exactly = True ,
444452 ).check_not (
445453 "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
446454 ).run (
@@ -451,7 +459,9 @@ def _test_embedding_torchao(
451459
452460 # After constant prop, we see quantized embedding op, but no packing op
453461 FileCheck ().check_count (
454- f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " , 1 , exactly = True ,
462+ f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " ,
463+ 1 ,
464+ exactly = True ,
455465 ).check_not (
456466 "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" ,
457467 ).run (
@@ -463,13 +473,14 @@ def _test_embedding_torchao(
463473 self .assertTrue (torch .allclose (expected_outputs , actual_outputs ))
464474
465475 # Can lower to executorch
466- exec_prog = m .to_executorch () # noqa
467-
476+ exec_prog = m .to_executorch () # noqa
468477
469478 # Alternative flow 2 using quant_fusion_pass on exported program
470479 quant_fusion_and_const_prop_pass (m_copy .exported_program ())
471480 FileCheck ().check_count (
472- f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " , 1 , exactly = True ,
481+ f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{ embedding_suffix } " ,
482+ 1 ,
483+ exactly = True ,
473484 ).check_not (
474485 "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default" ,
475486 ).run (
@@ -480,4 +491,4 @@ def _test_embedding_torchao(
480491 self .assertTrue (torch .allclose (expected_outputs , actual_outputs2 ))
481492
482493 # Can lower to executorch
483- exec_prog2 = m_copy .to_executorch () # noqa
494+ exec_prog2 = m_copy .to_executorch () # noqa
0 commit comments