66
77# pyre-unsafe
88
9- import unittest
10-
119from collections import Counter
12- from typing import Dict , Tuple
10+ from typing import Tuple
1311
1412import torch
1513from executorch .backends .xnnpack .quantizer .xnnpack_quantizer import (
3331from torch .testing ._internal .common_utils import (
3432 instantiate_parametrized_tests ,
3533 TemporaryFileName ,
36- TestCase ,
3734)
3835from torchao .quantization .pt2e import (
3936 allow_exported_model_train_eval ,
4037 compare_results ,
41- CUSTOM_KEY ,
4238 extract_results_from_loggers ,
43- generate_numeric_debug_handle ,
44- NUMERIC_DEBUG_HANDLE_KEY ,
39+ FROM_NODE_KEY ,
4540 prepare_for_propagation_comparison ,
4641)
4742
48- from torchao .quantization .pt2e .graph_utils import bfs_trace_with_node_process
4943from torchao .quantization .pt2e .quantize_pt2e import (
5044 convert_pt2e ,
5145 prepare_pt2e ,
5246 prepare_qat_pt2e ,
5347)
5448from torchao .quantization .pt2e .quantizer import ComposableQuantizer , Quantizer
5549from torchao .quantization .pt2e .quantizer .embedding_quantizer import EmbeddingQuantizer
56- from torchao .testing .pt2e .utils import PT2EQuantizationTestCase
50+ from torchao .testing .pt2e .utils import (
51+ PT2ENumericDebuggerTestCase ,
52+ PT2EQuantizationTestCase ,
53+ )
5754
5855
5956class TestQuantizePT2E (PT2EQuantizationTestCase ):
@@ -495,7 +492,8 @@ def forward(self, x):
495492 for n in m .graph .nodes :
496493 if n .op == "get_attr" and "frozen_param" in n .target :
497494 for key in n .meta :
498- self .assertEqual (n .meta [key ], weight_meta [key ])
495+ if key != FROM_NODE_KEY :
496+ self .assertEqual (n .meta [key ], weight_meta [key ])
499497
500498 def test_reentrant (self ) -> None :
501499 """Test we can safely call quantization apis multiple times"""
@@ -724,77 +722,59 @@ def test_save_load(self) -> None:
724722
725723instantiate_parametrized_tests (TestQuantizePT2E )
726724
725+ class TestXNNPACKQuantizerNumericDebugger (PT2ENumericDebuggerTestCase ):
727726
728- @unittest .skip ("TODO: Reenable it after debug infrature finish update" )
729- class TestNumericDebugger (TestCase ):
730- def _extract_debug_handles (self , model ) -> Dict [str , int ]:
731- debug_handle_map : Dict [str , int ] = {}
732-
733- def _extract_debug_handles_from_node (node : torch .fx .Node ) -> None :
734- nonlocal debug_handle_map
735- if (
736- CUSTOM_KEY in node .meta
737- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
738- ):
739- debug_handle_map [str (node )] = node .meta [CUSTOM_KEY ][
740- NUMERIC_DEBUG_HANDLE_KEY
741- ]
742-
743- bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
744- return debug_handle_map
745-
746- def _assert_each_node_has_debug_handle (self , model ) -> None :
747- def _assert_node_has_debug_handle (node : torch .fx .Node ) -> None :
748- self .assertTrue (
749- CUSTOM_KEY in node .meta
750- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ],
751- f"Node { node } doesn't have debug handle" ,
752- )
753-
754- bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
755-
756- def test_quantize_pt2e_preserve_handle (self ) -> None :
727+ def test_quantize_pt2e_preserve_handle (self ):
757728 m = TestHelperModules .Conv2dThenConv1d ()
758729 example_inputs = m .example_inputs ()
759730 ep = export_for_training (m , example_inputs , strict = True )
760- generate_numeric_debug_handle (ep )
761731 m = ep .module ()
762732
763733 quantizer = XNNPACKQuantizer ().set_global (
764734 get_symmetric_quantization_config (is_per_channel = False )
765735 )
766- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
767- debug_handle_map = self ._extract_debug_handles (m )
768- res_counter = Counter (debug_handle_map .values ())
769- repeated_debug_handle_ids = [1 , 2 , 3 ]
770- # 3 ids were repeated because we copy over the id from node to its output observer
736+ m = prepare_pt2e (m , quantizer )
737+ from_node_source_map = self ._extract_from_node_source (m )
738+ node_name_equip_with_output_observer = [
739+ "conv2d" ,
740+ "conv1d" ,
741+ "squeeze" ,
742+ ]
743+ res_counter = Counter (from_node_source_map .values ())
744+ repeated_from_node_source = [
745+ from_node_source_map [n_name ]
746+ for n_name in node_name_equip_with_output_observer
747+ ]
748+ # 3 infos were repeated because we copy over the info from node to its output observer
771749 # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
772- for dh_id in repeated_debug_handle_ids :
773- self .assertEqual (res_counter [dh_id ], 2 )
750+ for from_node_source in repeated_from_node_source :
751+ self .assertEqual (res_counter [from_node_source ], 2 )
774752
775753 m (* example_inputs )
776754 m = convert_pt2e (m )
777- self ._assert_each_node_has_debug_handle (ep )
778- debug_handle_map = self ._extract_debug_handles (m )
779- res_counter = Counter (debug_handle_map .values ())
780- # same set of ids where repeated, because we copy over the id from observer/fake_quant to
781- # dequantize node
782- repeated_debug_handle_ids = [1 , 2 , 3 ]
783- for dh_id in repeated_debug_handle_ids :
784- self .assertEqual (res_counter [dh_id ], 2 )
785-
786- def test_extract_results_from_loggers (self ) -> None :
755+ self ._assert_each_node_has_from_node_source (m )
756+ from_node_source_map = self ._extract_from_node_source (m )
757+ res_counter = Counter (from_node_source_map .values ())
758+ # same set of infos where repeated, because we copy over the info from observer/fake_quant to
759+ # quantize/dequantize node
760+ repeated_from_node_source = [
761+ from_node_source_map [n_name ]
762+ for n_name in node_name_equip_with_output_observer
763+ ]
764+ for from_node_source in repeated_from_node_source :
765+ self .assertEqual (res_counter [from_node_source ], 3 )
766+
767+ def test_extract_results_from_loggers (self ):
787768 m = TestHelperModules .Conv2dThenConv1d ()
788769 example_inputs = m .example_inputs ()
789770 ep = export_for_training (m , example_inputs , strict = True )
790- generate_numeric_debug_handle (ep )
791771 m = ep .module ()
792- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
772+ m_ref_logger = prepare_for_propagation_comparison (m )
793773
794774 quantizer = XNNPACKQuantizer ().set_global (
795775 get_symmetric_quantization_config (is_per_channel = False )
796776 )
797- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
777+ m = prepare_pt2e (m , quantizer )
798778 m (* example_inputs )
799779 m = convert_pt2e (m )
800780 m_quant_logger = prepare_for_propagation_comparison (m )
@@ -803,29 +783,22 @@ def test_extract_results_from_loggers(self) -> None:
803783 m_quant_logger (* example_inputs )
804784 ref_results = extract_results_from_loggers (m_ref_logger )
805785 quant_results = extract_results_from_loggers (m_quant_logger )
806- comparison_results = compare_results (
807- ref_results ,
808- quant_results , # pyre-ignore[6]
809- )
786+ comparison_results = compare_results (ref_results , quant_results )
810787 for node_summary in comparison_results .values ():
811788 if len (node_summary .results ) > 0 :
812- self .assertGreaterEqual (
813- node_summary .results [0 ].sqnr ,
814- 35 , # pyre-ignore[6]
815- )
789+ self .assertGreaterEqual (node_summary .results [0 ].sqnr , 35 )
816790
817- def test_extract_results_from_loggers_list_output (self ) -> None :
791+ def test_extract_results_from_loggers_list_output (self ):
818792 m = TestHelperModules .Conv2dWithSplit ()
819793 example_inputs = m .example_inputs ()
820794 ep = export_for_training (m , example_inputs , strict = True )
821- generate_numeric_debug_handle (ep )
822795 m = ep .module ()
823- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
796+ m_ref_logger = prepare_for_propagation_comparison (m )
824797
825798 quantizer = XNNPACKQuantizer ().set_global (
826799 get_symmetric_quantization_config (is_per_channel = False )
827800 )
828- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
801+ m = prepare_pt2e (m , quantizer )
829802 m (* example_inputs )
830803 m = convert_pt2e (m )
831804 m_quant_logger = prepare_for_propagation_comparison (m )
@@ -834,15 +807,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
834807 m_quant_logger (* example_inputs )
835808 ref_results = extract_results_from_loggers (m_ref_logger )
836809 quant_results = extract_results_from_loggers (m_quant_logger )
837- comparison_results = compare_results (
838- ref_results ,
839- quant_results , # pyre-ignore[6]
840- )
810+ comparison_results = compare_results (ref_results , quant_results )
841811 for node_summary in comparison_results .values ():
842812 if len (node_summary .results ) > 0 :
843813 sqnr = node_summary .results [0 ].sqnr
844814 if isinstance (sqnr , list ):
845815 for sqnr_i in sqnr :
846816 self .assertGreaterEqual (sqnr_i , 35 )
847817 else :
848- self .assertGreaterEqual (sqnr , 35 ) # pyre-ignore[6]
818+ self .assertGreaterEqual (sqnr , 35 )
0 commit comments