@@ -723,11 +723,22 @@ def test_save_load(self) -> None:
723723instantiate_parametrized_tests (TestQuantizePT2E )
724724
725725
726+ #TODO: deduplicate with TestNumericDebugger under torchao
726727class TestNumericDebugger (TestCase ):
727- def _extract_debug_handles (self , model ) -> Dict [str , int ]:
728- debug_handle_map : Dict [str , int ] = {}
728+ def _assert_each_node_has_debug_handle (self , model ) -> None :
729+ def _assert_node_has_debug_handle (node ):
730+ self .assertTrue (
731+ CUSTOM_KEY in node .meta
732+ and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ],
733+ f"Node { node } doesn't have debug handle" ,
734+ )
735+
736+ bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
729737
730- def _extract_debug_handles_from_node (node : torch .fx .Node ) -> None :
738+ def _extract_debug_handles (self , model ) -> dict [str , int ]:
739+ debug_handle_map : dict [str , int ] = {}
740+
741+ def _extract_debug_handles_from_node (node ):
731742 nonlocal debug_handle_map
732743 if (
733744 CUSTOM_KEY in node .meta
@@ -738,32 +749,55 @@ def _extract_debug_handles_from_node(node: torch.fx.Node) -> None:
738749 ]
739750
740751 bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
752+
741753 return debug_handle_map
742754
743- def _assert_each_node_has_debug_handle (self , model ) -> None :
744- def _assert_node_has_debug_handle (node : torch .fx .Node ) -> None :
745- self .assertTrue (
746- CUSTOM_KEY in node .meta
747- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ],
748- f"Node { node } doesn't have debug handle" ,
749- )
755+ def _extract_debug_handles_with_prev_decomp_op (self , model ) -> dict [str , int ]:
756+ prev_decomp_op_to_debug_handle_map : dict [str , int ] = {}
750757
751- bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
758+ def _extract_debug_handles_with_prev_decomp_op_from_node (node ):
759+ nonlocal prev_decomp_op_to_debug_handle_map
760+ if (
761+ CUSTOM_KEY in node .meta
762+ and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
763+ ):
764+ prev_decomp_op = str (node .meta .get ("nn_module_stack" ))
765+ debug_handle = node .meta [CUSTOM_KEY ][NUMERIC_DEBUG_HANDLE_KEY ]
766+ if prev_decomp_op not in prev_decomp_op_to_debug_handle_map :
767+ prev_decomp_op_to_debug_handle_map [prev_decomp_op ] = debug_handle
768+ else :
769+ assert (
770+ prev_decomp_op_to_debug_handle_map [prev_decomp_op ]
771+ == debug_handle
772+ ), f"Node { node } has different debug handle { debug_handle } "
773+ "than previous node sharing the same decomp op {prev_decomp_op}"
774+
775+ bfs_trace_with_node_process (
776+ model , _extract_debug_handles_with_prev_decomp_op_from_node
777+ )
778+ return prev_decomp_op_to_debug_handle_map
752779
753- def test_quantize_pt2e_preserve_handle (self ) -> None :
780+ def test_quantize_pt2e_preserve_handle (self ):
754781 m = TestHelperModules .Conv2dThenConv1d ()
755782 example_inputs = m .example_inputs ()
756783 ep = export_for_training (m , example_inputs , strict = True )
757- generate_numeric_debug_handle (ep )
784+ # generate_numeric_debug_handle(ep)
758785 m = ep .module ()
759786
760787 quantizer = XNNPACKQuantizer ().set_global (
761788 get_symmetric_quantization_config (is_per_channel = False )
762789 )
763- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
790+ m = prepare_pt2e (m , quantizer )
764791 debug_handle_map = self ._extract_debug_handles (m )
792+ node_name_equip_with_output_observer = [
793+ "conv2d" ,
794+ "conv1d" ,
795+ "squeeze" ,
796+ ]
765797 res_counter = Counter (debug_handle_map .values ())
766- repeated_debug_handle_ids = [1 , 2 , 3 ]
798+ repeated_debug_handle_ids = [
799+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
800+ ]
767801 # 3 ids were repeated because we copy over the id from node to its output observer
768802 # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
769803 for dh_id in repeated_debug_handle_ids :
@@ -776,22 +810,23 @@ def test_quantize_pt2e_preserve_handle(self) -> None:
776810 res_counter = Counter (debug_handle_map .values ())
777811 # same set of ids where repeated, because we copy over the id from observer/fake_quant to
778812 # dequantize node
779- repeated_debug_handle_ids = [1 , 2 , 3 ]
813+ repeated_debug_handle_ids = [
814+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
815+ ]
780816 for dh_id in repeated_debug_handle_ids :
781817 self .assertEqual (res_counter [dh_id ], 2 )
782818
783- def test_extract_results_from_loggers (self ) -> None :
819+ def test_extract_results_from_loggers (self ):
784820 m = TestHelperModules .Conv2dThenConv1d ()
785821 example_inputs = m .example_inputs ()
786822 ep = export_for_training (m , example_inputs , strict = True )
787- generate_numeric_debug_handle (ep )
788823 m = ep .module ()
789- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
824+ m_ref_logger = prepare_for_propagation_comparison (m )
790825
791826 quantizer = XNNPACKQuantizer ().set_global (
792827 get_symmetric_quantization_config (is_per_channel = False )
793828 )
794- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
829+ m = prepare_pt2e (m , quantizer )
795830 m (* example_inputs )
796831 m = convert_pt2e (m )
797832 m_quant_logger = prepare_for_propagation_comparison (m )
@@ -800,29 +835,22 @@ def test_extract_results_from_loggers(self) -> None:
800835 m_quant_logger (* example_inputs )
801836 ref_results = extract_results_from_loggers (m_ref_logger )
802837 quant_results = extract_results_from_loggers (m_quant_logger )
803- comparison_results = compare_results (
804- ref_results ,
805- quant_results , # pyre-ignore[6]
806- )
838+ comparison_results = compare_results (ref_results , quant_results )
807839 for node_summary in comparison_results .values ():
808840 if len (node_summary .results ) > 0 :
809- self .assertGreaterEqual (
810- node_summary .results [0 ].sqnr ,
811- 35 , # pyre-ignore[6]
812- )
841+ self .assertGreaterEqual (node_summary .results [0 ].sqnr , 35 )
813842
814- def test_extract_results_from_loggers_list_output (self ) -> None :
843+ def test_extract_results_from_loggers_list_output (self ):
815844 m = TestHelperModules .Conv2dWithSplit ()
816845 example_inputs = m .example_inputs ()
817846 ep = export_for_training (m , example_inputs , strict = True )
818- generate_numeric_debug_handle (ep )
819847 m = ep .module ()
820- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
848+ m_ref_logger = prepare_for_propagation_comparison (m )
821849
822850 quantizer = XNNPACKQuantizer ().set_global (
823851 get_symmetric_quantization_config (is_per_channel = False )
824852 )
825- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
853+ m = prepare_pt2e (m , quantizer )
826854 m (* example_inputs )
827855 m = convert_pt2e (m )
828856 m_quant_logger = prepare_for_propagation_comparison (m )
@@ -831,15 +859,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
831859 m_quant_logger (* example_inputs )
832860 ref_results = extract_results_from_loggers (m_ref_logger )
833861 quant_results = extract_results_from_loggers (m_quant_logger )
834- comparison_results = compare_results (
835- ref_results ,
836- quant_results , # pyre-ignore[6]
837- )
862+ comparison_results = compare_results (ref_results , quant_results )
838863 for node_summary in comparison_results .values ():
839864 if len (node_summary .results ) > 0 :
840865 sqnr = node_summary .results [0 ].sqnr
841866 if isinstance (sqnr , list ):
842867 for sqnr_i in sqnr :
843868 self .assertGreaterEqual (sqnr_i , 35 )
844869 else :
845- self .assertGreaterEqual (sqnr , 35 ) # pyre-ignore[6]
870+ self .assertGreaterEqual (sqnr , 35 )
0 commit comments