1515from torch .testing ._internal .common_utils import IS_WINDOWS , TestCase , run_tests
1616
1717from torchao .quantization .pt2e import (
18- CUSTOM_KEY ,
19- NUMERIC_DEBUG_HANDLE_KEY ,
18+ FROM_NODE_KEY ,
2019 compare_results ,
2120 extract_results_from_loggers ,
22- generate_numeric_debug_handle ,
2321 prepare_for_propagation_comparison ,
2422)
23+ from torchao .quantization .pt2e ._numeric_debugger import _generate_debug_handle_from_node
2524from torchao .quantization .pt2e .graph_utils import bfs_trace_with_node_process
2625from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
2726from torchao .testing .pt2e ._xnnpack_quantizer import (
3938class TestNumericDebugger (TestCase ):
4039 def _assert_each_node_has_debug_handle (self , model ) -> None :
4140 def _assert_node_has_debug_handle (node ):
42- self .assertTrue (
43- CUSTOM_KEY in node . meta
44- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [ CUSTOM_KEY ] ,
45- f"Node { node } doesn't have debug handle " ,
41+ self .assertIn (
42+ FROM_NODE_KEY ,
43+ node .meta ,
44+ f"Node { node } doesn't have from_node info " ,
4645 )
4746
4847 bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
@@ -52,13 +51,8 @@ def _extract_debug_handles(self, model) -> dict[str, int]:
5251
5352 def _extract_debug_handles_from_node (node ):
5453 nonlocal debug_handle_map
55- if (
56- CUSTOM_KEY in node .meta
57- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
58- ):
59- debug_handle_map [str (node )] = node .meta [CUSTOM_KEY ][
60- NUMERIC_DEBUG_HANDLE_KEY
61- ]
54+ if (dh := _generate_debug_handle_from_node (node )) is not None :
55+ debug_handle_map [str (node )] = dh
6256
6357 bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
6458
@@ -69,12 +63,9 @@ def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
6963
7064 def _extract_debug_handles_with_prev_decomp_op_from_node (node ):
7165 nonlocal prev_decomp_op_to_debug_handle_map
72- if (
73- CUSTOM_KEY in node .meta
74- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
75- ):
66+ if FROM_NODE_KEY in node .meta :
7667 prev_decomp_op = str (node .meta .get ("nn_module_stack" ))
77- debug_handle = node . meta [ CUSTOM_KEY ][ NUMERIC_DEBUG_HANDLE_KEY ]
68+ debug_handle = _generate_debug_handle_from_node ( node )
7869 if prev_decomp_op not in prev_decomp_op_to_debug_handle_map :
7970 prev_decomp_op_to_debug_handle_map [prev_decomp_op ] = debug_handle
8071 else :
@@ -96,64 +87,73 @@ def test_simple(self):
9687 m = TestHelperModules .Conv2dThenConv1d ()
9788 example_inputs = m .example_inputs ()
9889 ep = export_for_training (m , example_inputs , strict = True )
99- generate_numeric_debug_handle ( ep )
100- self ._assert_each_node_has_debug_handle (ep )
101- debug_handle_map = self ._extract_debug_handles (ep )
90+ m = ep . module ( )
91+ self ._assert_each_node_has_debug_handle (m )
92+ debug_handle_map = self ._extract_debug_handles (m )
10293
10394 self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
10495
96+ @unittest .skip ("debug flow not working on model with conditional control flow" )
10597 def test_control_flow (self ):
10698 m = TestHelperModules .ControlFlow ()
10799 example_inputs = m .example_inputs ()
108100 ep = export_for_training (m , example_inputs , strict = True )
109- generate_numeric_debug_handle ( ep )
101+ m = ep . module ( )
110102
111- self ._assert_each_node_has_debug_handle (ep )
112- debug_handle_map = self ._extract_debug_handles (ep )
103+ self ._assert_each_node_has_debug_handle (m )
104+ debug_handle_map = self ._extract_debug_handles (m )
113105
114106 self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
115107
116108 def test_quantize_pt2e_preserve_handle (self ):
117109 m = TestHelperModules .Conv2dThenConv1d ()
118110 example_inputs = m .example_inputs ()
119111 ep = export_for_training (m , example_inputs , strict = True )
120- generate_numeric_debug_handle (ep )
121112 m = ep .module ()
122113
123114 quantizer = XNNPACKQuantizer ().set_global (
124115 get_symmetric_quantization_config (is_per_channel = False )
125116 )
126117 m = prepare_pt2e (m , quantizer )
127118 debug_handle_map = self ._extract_debug_handles (m )
119+ node_name_equip_with_output_observer = [
120+ "conv2d" ,
121+ "conv1d" ,
122+ "squeeze" ,
123+ ]
128124 res_counter = Counter (debug_handle_map .values ())
129- repeated_debug_handle_ids = [1 , 2 , 3 ]
125+ repeated_debug_handle_ids = [
126+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
127+ ]
130128 # 3 ids were repeated because we copy over the id from node to its output observer
131129 # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
132130 for dh_id in repeated_debug_handle_ids :
133131 self .assertEqual (res_counter [dh_id ], 2 )
134132
135133 m (* example_inputs )
136134 m = convert_pt2e (m )
137- self ._assert_each_node_has_debug_handle (ep )
135+ self ._assert_each_node_has_debug_handle (m )
138136 debug_handle_map = self ._extract_debug_handles (m )
139137 res_counter = Counter (debug_handle_map .values ())
140138 # same set of ids where repeated, because we copy over the id from observer/fake_quant to
141- # dequantize node
142- repeated_debug_handle_ids = [1 , 2 , 3 ]
139+ # quantize/dequantize node
140+ repeated_debug_handle_ids = [
141+ debug_handle_map [n_name ] for n_name in node_name_equip_with_output_observer
142+ ]
143143 for dh_id in repeated_debug_handle_ids :
144- self .assertEqual (res_counter [dh_id ], 2 )
144+ self .assertEqual (res_counter [dh_id ], 3 )
145145
146146 def test_copy_preserve_handle (self ):
147147 m = TestHelperModules .Conv2dThenConv1d ()
148148 example_inputs = m .example_inputs ()
149149 ep = torch .export .export (m , example_inputs , strict = True )
150- generate_numeric_debug_handle ( ep )
150+ m = ep . module ( )
151151
152- self ._assert_each_node_has_debug_handle (ep )
153- debug_handle_map_ref = self ._extract_debug_handles (ep )
152+ self ._assert_each_node_has_debug_handle (m )
153+ debug_handle_map_ref = self ._extract_debug_handles (m )
154154
155155 ep_copy = copy .copy (ep )
156- debug_handle_map = self ._extract_debug_handles (ep_copy )
156+ debug_handle_map = self ._extract_debug_handles (ep_copy . module () )
157157
158158 self ._assert_each_node_has_debug_handle (ep )
159159 self .assertEqual (debug_handle_map , debug_handle_map_ref )
@@ -162,13 +162,12 @@ def test_deepcopy_preserve_handle(self):
162162 m = TestHelperModules .Conv2dThenConv1d ()
163163 example_inputs = m .example_inputs ()
164164 ep = torch .export .export (m , example_inputs , strict = True )
165- generate_numeric_debug_handle (ep )
166165
167- debug_handle_map_ref = self ._extract_debug_handles (ep )
166+ debug_handle_map_ref = self ._extract_debug_handles (ep . module () )
168167 ep_copy = copy .deepcopy (ep )
169- debug_handle_map = self ._extract_debug_handles (ep_copy )
168+ debug_handle_map = self ._extract_debug_handles (ep_copy . module () )
170169
171- self ._assert_each_node_has_debug_handle (ep )
170+ self ._assert_each_node_has_debug_handle (ep . module () )
172171 self .assertEqual (debug_handle_map , debug_handle_map_ref )
173172
174173 @unittest .skip (
@@ -178,16 +177,16 @@ def test_re_export_preserve_handle(self):
178177 m = TestHelperModules .Conv2dThenConv1d ()
179178 example_inputs = m .example_inputs ()
180179 ep = export_for_training (m , example_inputs , strict = True )
181- generate_numeric_debug_handle (ep )
182180 m = ep .module ()
183181
184- self ._assert_each_node_has_debug_handle (ep )
185- debug_handle_map_ref = self ._extract_debug_handles (ep )
182+ self ._assert_each_node_has_debug_handle (m )
183+ debug_handle_map_ref = self ._extract_debug_handles (m )
186184
187185 ep_reexport = export_for_training (m , example_inputs , strict = True )
186+ m_reexport = ep_reexport .module ()
188187
189- self ._assert_each_node_has_debug_handle (ep_reexport )
190- debug_handle_map = self ._extract_debug_handles (ep_reexport )
188+ self ._assert_each_node_has_debug_handle (m_reexport )
189+ debug_handle_map = self ._extract_debug_handles (m_reexport )
191190
192191 self .assertEqual (debug_handle_map , debug_handle_map_ref )
193192
@@ -198,16 +197,17 @@ def test_run_decompositions_same_handle_id(self):
198197 m = TestHelperModules .Conv2dThenConv1d ()
199198 example_inputs = m .example_inputs ()
200199 ep = export_for_training (m , example_inputs , strict = True )
201- generate_numeric_debug_handle ( ep )
200+ m = ep . module ( )
202201
203- self ._assert_each_node_has_debug_handle (ep )
204- debug_handle_map_ref = self ._extract_debug_handles (ep )
202+ self ._assert_each_node_has_debug_handle (m )
203+ debug_handle_map_ref = self ._extract_debug_handles (m )
205204
206205 ep_copy = copy .copy (ep )
207206 ep_copy = ep_copy .run_decompositions ()
207+ m_decomposed = ep_copy .module ()
208208
209- self ._assert_each_node_has_debug_handle (ep_copy )
210- debug_handle_map = self ._extract_debug_handles (ep_copy )
209+ self ._assert_each_node_has_debug_handle (m_decomposed )
210+ debug_handle_map = self ._extract_debug_handles (m_decomposed )
211211
212212 # checking the map still has the same ids, the node may change
213213 self .assertEqual (
@@ -226,18 +226,19 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
226226 for m in test_models :
227227 example_inputs = m .example_inputs ()
228228 ep = export_for_training (m , example_inputs , strict = True )
229- generate_numeric_debug_handle ( ep )
229+ m = ep . module ( )
230230
231- self ._assert_each_node_has_debug_handle (ep )
231+ self ._assert_each_node_has_debug_handle (m )
232232 pre_decomp_to_debug_handle_map_ref = (
233- self ._extract_debug_handles_with_prev_decomp_op (ep )
233+ self ._extract_debug_handles_with_prev_decomp_op (m )
234234 )
235235
236236 ep_copy = copy .copy (ep )
237237 ep_copy = ep_copy .run_decompositions ()
238- self ._assert_each_node_has_debug_handle (ep_copy )
238+ m_decomposed = ep_copy .module ()
239+ self ._assert_each_node_has_debug_handle (m_decomposed )
239240 pre_decomp_to_debug_handle_map = (
240- self ._extract_debug_handles_with_prev_decomp_op (ep_copy )
241+ self ._extract_debug_handles_with_prev_decomp_op (m_decomposed )
241242 )
242243
243244 # checking the map still has the same ids, the node may change
@@ -249,7 +250,6 @@ def test_prepare_for_propagation_comparison(self):
249250 m = TestHelperModules .Conv2dThenConv1d ()
250251 example_inputs = m .example_inputs ()
251252 ep = export_for_training (m , example_inputs , strict = True )
252- generate_numeric_debug_handle (ep )
253253 m = ep .module ()
254254 m_logger = prepare_for_propagation_comparison (m )
255255 ref = m (* example_inputs )
@@ -266,7 +266,6 @@ def test_extract_results_from_loggers(self):
266266 m = TestHelperModules .Conv2dThenConv1d ()
267267 example_inputs = m .example_inputs ()
268268 ep = export_for_training (m , example_inputs , strict = True )
269- generate_numeric_debug_handle (ep )
270269 m = ep .module ()
271270 m_ref_logger = prepare_for_propagation_comparison (m )
272271
@@ -291,7 +290,6 @@ def test_extract_results_from_loggers_list_output(self):
291290 m = TestHelperModules .Conv2dWithSplit ()
292291 example_inputs = m .example_inputs ()
293292 ep = export_for_training (m , example_inputs , strict = True )
294- generate_numeric_debug_handle (ep )
295293 m = ep .module ()
296294 m_ref_logger = prepare_for_propagation_comparison (m )
297295
@@ -321,9 +319,10 @@ def test_added_node_gets_unique_id(self) -> None:
321319 m = TestHelperModules .Conv2dThenConv1d ()
322320 example_inputs = m .example_inputs ()
323321 ep = export_for_training (m , example_inputs , strict = True )
324- generate_numeric_debug_handle ( ep )
325- ref_handles = self ._extract_debug_handles (ep )
322+
323+ ref_handles = self ._extract_debug_handles (ep . module () )
326324 ref_counter = Counter (ref_handles .values ())
325+
327326 for k , v in ref_counter .items ():
328327 self .assertEqual (
329328 v ,
@@ -345,10 +344,10 @@ def test_added_node_gets_unique_id(self) -> None:
345344
346345 # Regenerate handles, make sure only the new relu node has a new id, and
347346 # it doesn't clash with any of the existing ids.
348- generate_numeric_debug_handle (ep )
349347
350- self ._assert_each_node_has_debug_handle (ep )
351- handles_after_modification = self ._extract_debug_handles (ep )
348+ m = ep .module ()
349+ self ._assert_each_node_has_debug_handle (m )
350+ handles_after_modification = self ._extract_debug_handles (m )
352351 handles_counter = Counter (handles_after_modification .values ())
353352 for name , handle in ref_handles .items ():
354353 self .assertIn (name , handles_after_modification )
@@ -365,7 +364,7 @@ def test_added_node_gets_unique_id(self) -> None:
365364
366365 # Check for relu specifically. Avoid hardcoding the handle id since it
367366 # may change with future node ordering changes.
368- self .assertNotEqual (handles_after_modification ["relu_default" ], 0 )
367+ self .assertNotIn (handles_after_modification ["relu_default" ], ref_counter )
369368 self .assertEqual (handles_counter [handles_after_modification ["relu_default" ]], 1 )
370369
371370
0 commit comments