77# pyre-unsafe
88
99from collections import Counter
10- from typing import Dict , Tuple
10+ from typing import Tuple
1111
1212import torch
1313from executorch .backends .xnnpack .quantizer .xnnpack_quantizer import (
3636from torchao .quantization .pt2e import (
3737 allow_exported_model_train_eval ,
3838 compare_results ,
39- CUSTOM_KEY ,
4039 extract_results_from_loggers ,
41- generate_numeric_debug_handle ,
42- NUMERIC_DEBUG_HANDLE_KEY ,
40+ FROM_NODE_KEY ,
4341 prepare_for_propagation_comparison ,
4442)
4543
44+ from torchao .quantization .pt2e ._numeric_debugger import _generate_debug_handle_from_node
45+
4646from torchao .quantization .pt2e .graph_utils import bfs_trace_with_node_process
4747from torchao .quantization .pt2e .quantize_pt2e import (
4848 convert_pt2e ,
@@ -723,75 +723,101 @@ def test_save_load(self) -> None:
723723instantiate_parametrized_tests (TestQuantizePT2E )
724724
725725
726+ # TODO: deduplicate with TestNumericDebugger under torchao
726727class TestNumericDebugger (TestCase ):
727- def _extract_debug_handles (self , model ) -> Dict [str , int ]:
728- debug_handle_map : Dict [str , int ] = {}
728+ def _assert_each_node_has_debug_handle (self , model ) -> None :
729+ def _assert_node_has_debug_handle (node ):
730+ self .assertIn (
731+ FROM_NODE_KEY ,
732+ node .meta ,
733+ f"Node { node } doesn't have from_node info" ,
734+ )
735+
736+ bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
737+
738+ def _extract_debug_handles (self , model ) -> dict [str , int ]:
739+ debug_handle_map : dict [str , int ] = {}
729740
730- def _extract_debug_handles_from_node (node : torch . fx . Node ) -> None :
741+ def _extract_debug_handles_from_node (node ) :
731742 nonlocal debug_handle_map
732- if (
733- CUSTOM_KEY in node .meta
734- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
735- ):
736- debug_handle_map [str (node )] = node .meta [CUSTOM_KEY ][
737- NUMERIC_DEBUG_HANDLE_KEY
738- ]
743+ if (dh := _generate_debug_handle_from_node (node )) is not None :
744+ debug_handle_map [str (node )] = dh
739745
740746 bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
747+
741748 return debug_handle_map
742749
743- def _assert_each_node_has_debug_handle (self , model ) -> None :
744- def _assert_node_has_debug_handle (node : torch .fx .Node ) -> None :
745- self .assertTrue (
746- CUSTOM_KEY in node .meta
747- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ],
748- f"Node { node } doesn't have debug handle" ,
749- )
750+ def _extract_debug_handles_with_prev_decomp_op (self , model ) -> dict [str , int ]:
751+ prev_decomp_op_to_debug_handle_map : dict [str , int ] = {}
750752
751- bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
753+ def _extract_debug_handles_with_prev_decomp_op_from_node (node ):
754+ nonlocal prev_decomp_op_to_debug_handle_map
755+ if FROM_NODE_KEY in node .meta :
756+ prev_decomp_op = str (node .meta .get ("nn_module_stack" ))
757+ debug_handle = _generate_debug_handle_from_node (node )
758+ if prev_decomp_op not in prev_decomp_op_to_debug_handle_map :
759+ prev_decomp_op_to_debug_handle_map [prev_decomp_op ] = debug_handle
760+ else :
761+ assert (
762+ prev_decomp_op_to_debug_handle_map [prev_decomp_op ]
763+ == debug_handle
764+ ), f"Node { node } has different debug handle { debug_handle } "
765+ "than previous node sharing the same decomp op {prev_decomp_op}"
766+
767+ bfs_trace_with_node_process (
768+ model , _extract_debug_handles_with_prev_decomp_op_from_node
769+ )
770+ return prev_decomp_op_to_debug_handle_map
752771
753- def test_quantize_pt2e_preserve_handle (self ) -> None :
772+ def test_quantize_pt2e_preserve_handle (self ):
754773 m = TestHelperModules .Conv2dThenConv1d ()
755774 example_inputs = m .example_inputs ()
756775 ep = export_for_training (m , example_inputs , strict = True )
757- generate_numeric_debug_handle (ep )
758776 m = ep .module ()
759777
760778 quantizer = XNNPACKQuantizer ().set_global (
761779 get_symmetric_quantization_config (is_per_channel = False )
762780 )
763- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
781+ m = prepare_pt2e (m , quantizer )
764782 debug_handle_map = self ._extract_debug_handles (m )
783+ node_name_equip_with_output_observer = [
784+ "conv2d" ,
785+ "conv1d" ,
786+ "squeeze" ,
787+ ]
765788 res_counter = Counter (debug_handle_map .values ())
766- repeated_debug_handle_ids = [1 , 2 , 3 ]
789+ repeated_debug_handle_ids = [
790+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
791+ ]
767792 # 3 ids were repeated because we copy over the id from node to its output observer
768793 # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
769794 for dh_id in repeated_debug_handle_ids :
770795 self .assertEqual (res_counter [dh_id ], 2 )
771796
772797 m (* example_inputs )
773798 m = convert_pt2e (m )
774- self ._assert_each_node_has_debug_handle (ep )
799+ self ._assert_each_node_has_debug_handle (m )
775800 debug_handle_map = self ._extract_debug_handles (m )
776801 res_counter = Counter (debug_handle_map .values ())
777802 # same set of ids where repeated, because we copy over the id from observer/fake_quant to
778- # dequantize node
779- repeated_debug_handle_ids = [1 , 2 , 3 ]
803+ # quantize/dequantize node
804+ repeated_debug_handle_ids = [
805+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
806+ ]
780807 for dh_id in repeated_debug_handle_ids :
781- self .assertEqual (res_counter [dh_id ], 2 )
808+ self .assertEqual (res_counter [dh_id ], 3 )
782809
783- def test_extract_results_from_loggers (self ) -> None :
810+ def test_extract_results_from_loggers (self ):
784811 m = TestHelperModules .Conv2dThenConv1d ()
785812 example_inputs = m .example_inputs ()
786813 ep = export_for_training (m , example_inputs , strict = True )
787- generate_numeric_debug_handle (ep )
788814 m = ep .module ()
789- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
815+ m_ref_logger = prepare_for_propagation_comparison (m )
790816
791817 quantizer = XNNPACKQuantizer ().set_global (
792818 get_symmetric_quantization_config (is_per_channel = False )
793819 )
794- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
820+ m = prepare_pt2e (m , quantizer )
795821 m (* example_inputs )
796822 m = convert_pt2e (m )
797823 m_quant_logger = prepare_for_propagation_comparison (m )
@@ -800,29 +826,22 @@ def test_extract_results_from_loggers(self) -> None:
800826 m_quant_logger (* example_inputs )
801827 ref_results = extract_results_from_loggers (m_ref_logger )
802828 quant_results = extract_results_from_loggers (m_quant_logger )
803- comparison_results = compare_results (
804- ref_results ,
805- quant_results , # pyre-ignore[6]
806- )
829+ comparison_results = compare_results (ref_results , quant_results )
807830 for node_summary in comparison_results .values ():
808831 if len (node_summary .results ) > 0 :
809- self .assertGreaterEqual (
810- node_summary .results [0 ].sqnr ,
811- 35 , # pyre-ignore[6]
812- )
832+ self .assertGreaterEqual (node_summary .results [0 ].sqnr , 35 )
813833
814- def test_extract_results_from_loggers_list_output (self ) -> None :
834+ def test_extract_results_from_loggers_list_output (self ):
815835 m = TestHelperModules .Conv2dWithSplit ()
816836 example_inputs = m .example_inputs ()
817837 ep = export_for_training (m , example_inputs , strict = True )
818- generate_numeric_debug_handle (ep )
819838 m = ep .module ()
820- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
839+ m_ref_logger = prepare_for_propagation_comparison (m )
821840
822841 quantizer = XNNPACKQuantizer ().set_global (
823842 get_symmetric_quantization_config (is_per_channel = False )
824843 )
825- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
844+ m = prepare_pt2e (m , quantizer )
826845 m (* example_inputs )
827846 m = convert_pt2e (m )
828847 m_quant_logger = prepare_for_propagation_comparison (m )
@@ -831,15 +850,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
831850 m_quant_logger (* example_inputs )
832851 ref_results = extract_results_from_loggers (m_ref_logger )
833852 quant_results = extract_results_from_loggers (m_quant_logger )
834- comparison_results = compare_results (
835- ref_results ,
836- quant_results , # pyre-ignore[6]
837- )
853+ comparison_results = compare_results (ref_results , quant_results )
838854 for node_summary in comparison_results .values ():
839855 if len (node_summary .results ) > 0 :
840856 sqnr = node_summary .results [0 ].sqnr
841857 if isinstance (sqnr , list ):
842858 for sqnr_i in sqnr :
843859 self .assertGreaterEqual (sqnr_i , 35 )
844860 else :
845- self .assertGreaterEqual (sqnr , 35 ) # pyre-ignore[6]
861+ self .assertGreaterEqual (sqnr , 35 )
0 commit comments