Skip to content

Commit 5eb14b7

Browse files
[et] generate debug handle before operator decomposition (#12275)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11997 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/17/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/17/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/17/orig @diff-train-skip-merge Co-authored-by: gasoonjia <[email protected]>
1 parent f9a3ca8 commit 5eb14b7

14 files changed

+266
-96
lines changed

devtools/inspector/_inspector.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def _populate_debugging_related_fields(
658658

659659
def _associate_with_op_graph_nodes(
660660
self,
661-
debug_handle_to_op_node_map: Dict[int, OperatorNode],
661+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]],
662662
) -> None:
663663
"""
664664
Helper function to populate the stack_traces, module_hierarchy and op_types attributes
@@ -676,14 +676,21 @@ def _associate_with_op_graph_nodes(
676676
debug_handles = [debug_handles]
677677

678678
for handle in debug_handles:
679-
node = debug_handle_to_op_node_map.get(handle)
680-
# Attach node metadata including stack traces, module hierarchy and op_types to this event
681-
if node is not None and (metadata := node.metadata) is not None:
682-
self.stack_traces[node.name] = metadata.get("stack_trace")
683-
self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
684-
if node.op:
685-
# TODO: consider having this as a dict from node.name -> node.op
686-
self.op_types += [node.op]
679+
nodes = debug_handle_to_op_node_map.get(handle, None)
680+
if nodes is None:
681+
continue
682+
683+
for node in nodes:
684+
# Attach node metadata including stack traces, module hierarchy and op_types to this event
685+
if node is not None and (metadata := node.metadata) is not None:
686+
if node.name not in self.stack_traces:
687+
self.stack_traces[node.name] = metadata.get("stack_trace")
688+
self.module_hierarchy[node.name] = metadata.get(
689+
"nn_module_stack"
690+
)
691+
if node.op:
692+
# TODO: consider having this as a dict from node.name -> node.op
693+
self.op_types += [node.op]
687694

688695

689696
@dataclass

devtools/inspector/_inspector_utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,23 @@ def gen_graphs_from_etrecord(
303303
return op_graph_map
304304

305305

306+
# One debug handle should only be associated with one node. We are in the middle of migrating debug handle generation
307+
# from graph after to_edge to graph after torch.export, so every debug handle in exported graph may be associated with multiple nodes in to_edge
308+
# graph. After full migration, we should bring the return type as well as the #node check back.
309+
#
310+
# Before migration: returned Dict for 1 debug handle to 1 node in to_edge graph
311+
# During migration: returned Dict for 1 debug handle to multiple nodes in to_edge graph
312+
# After migration: returned Dict for 1 debug handle to 1 node in exported graph
313+
#
314+
# TODO(gasoonjia): recover the return type to Dict[int, OperatorNode], re-enable the #node check.
306315
def create_debug_handle_to_op_node_mapping(
307316
op_graph: OperatorGraph,
308-
) -> Dict[int, OperatorNode]:
317+
) -> Dict[int, List[OperatorNode]]:
309318
"""
310319
Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
311320
from each debug handle to the operator node that contains the debug handle in its metadata.
312321
"""
313-
debug_handle_to_op_node_map: Dict[int, OperatorNode] = {}
322+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {}
314323

315324
# Recursively searches through the metadata of nodes
316325
def _extract_debug_handles(graph: OperatorGraph):
@@ -320,14 +329,13 @@ def _extract_debug_handles(graph: OperatorGraph):
320329
if isinstance(element, OperatorNode) and element.metadata is not None:
321330
metadata = element.metadata
322331
debug_handle = metadata.get("debug_handle")
323-
if debug_handle is not None:
324-
existing_entry = debug_handle_to_op_node_map.get(debug_handle)
325-
if existing_entry is not None:
326-
raise ValueError(
327-
f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
328-
"No two op nodes of the same graph should have the same debug handle."
329-
)
330-
debug_handle_to_op_node_map[debug_handle] = element
332+
if debug_handle is None:
333+
continue
334+
335+
if debug_handle not in debug_handle_to_op_node_map:
336+
debug_handle_to_op_node_map[debug_handle] = []
337+
338+
debug_handle_to_op_node_map[debug_handle].append(element)
331339

332340
# Start traversing
333341
_extract_debug_handles(op_graph)

devtools/inspector/tests/inspector_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,11 @@ def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self):
183183

184184
# Call the method that's under testing and verify
185185
event_with_single_debug_handle._associate_with_op_graph_nodes(
186-
{debug_handle: node_0}
186+
{
187+
debug_handle: [
188+
node_0,
189+
]
190+
}
187191
)
188192

189193
expected_stack_traces = {"node_0": "stack_trace_relu"}
@@ -226,7 +230,14 @@ def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self):
226230

227231
# Call the method that's under testing and verify
228232
event_with_multiple_debug_handles._associate_with_op_graph_nodes(
229-
{debug_handles[0]: node_0, debug_handles[1]: node_1}
233+
{
234+
debug_handles[0]: [
235+
node_0,
236+
],
237+
debug_handles[1]: [
238+
node_1,
239+
],
240+
}
230241
)
231242

232243
expected_stack_traces = {

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,17 @@ def get_expected_intermediate_outputs():
6262
Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
6363
"""
6464
return {
65-
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
66-
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
67-
(12,): torch.tensor(
68-
[
69-
[0.1000, 0.5000],
70-
[0.2000, 0.6000],
71-
[0.3000, 0.7000],
72-
[0.4000, 0.8000],
73-
]
74-
),
75-
(13,): torch.tensor([[5.0000, 14.1200]]),
76-
(14,): torch.tensor([[5.5000, 13.6200]]),
77-
(15,): torch.tensor([[5.4000, 13.5200]]),
78-
(16,): torch.tensor([[10.8000, 6.7600]]),
79-
(17,): torch.tensor([3.0000, 1.5000]),
80-
(18,): torch.tensor([[3.6000, 4.5067]]),
81-
(19,): torch.tensor([[3.6000, 4.5067]]),
82-
(20,): torch.tensor([[0.9734, 0.9891]]),
83-
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
65+
(1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
66+
(2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
67+
(3,): torch.tensor([[5.0000, 14.1200]]),
68+
(4,): torch.tensor([[5.5000, 13.6200]]),
69+
(5,): torch.tensor([[5.4000, 13.5200]]),
70+
(6,): torch.tensor([[10.8000, 6.7600]]),
71+
(7,): torch.tensor([3.0000, 1.5000]),
72+
(8,): torch.tensor([[3.6000, 4.5067]]),
73+
(9,): torch.tensor([[3.6000, 4.5067]]),
74+
(10,): torch.tensor([[0.9734, 0.9891]]),
75+
(11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
8476
}
8577

8678
@staticmethod
@@ -89,18 +81,17 @@ def get_expected_debug_handle_to_op_name():
8981
Returns the expected debug handle and op name mapping for this model for the given input.
9082
"""
9183
return {
92-
(10,): "aten_convolution_default",
93-
(11,): "aten_view_copy_default",
94-
(12,): "aten_permute_copy_default",
95-
(13,): "aten_addmm_default",
96-
(14,): "aten_add_tensor",
97-
(15,): "aten_sub_tensor",
98-
(16,): "aten_mul_tensor",
99-
(17,): "aten_add_tensor_1",
100-
(18,): "aten_div_tensor",
101-
(19,): "aten_relu_default",
102-
(20,): "aten_sigmoid_default",
103-
(21,): "aten_split_with_sizes_copy_default",
84+
(1,): "aten_convolution_default",
85+
(2,): "aten_view_copy_default",
86+
(3,): "aten_addmm_default",
87+
(4,): "aten_add_tensor",
88+
(5,): "aten_sub_tensor",
89+
(6,): "aten_mul_tensor",
90+
(7,): "aten_add_tensor_1",
91+
(8,): "aten_div_tensor",
92+
(9,): "aten_relu_default",
93+
(10,): "aten_sigmoid_default",
94+
(11,): "aten_split_with_sizes_copy_default",
10495
}
10596

10697

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
583583
"nn_module_stack": "module_hierarchy_relu",
584584
},
585585
)
586-
mapping[111] = node_fused_conv_relu
586+
mapping[111] = [
587+
node_fused_conv_relu,
588+
]
587589
node_sin = OperatorNode(
588590
"sin",
589591
[node_fused_conv_relu],
@@ -594,7 +596,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
594596
"nn_module_stack": "module_hierarchy_sin",
595597
},
596598
)
597-
mapping[222] = node_sin
599+
mapping[222] = [
600+
node_sin,
601+
]
598602
node_cos = OperatorNode(
599603
"cos",
600604
[node_sin],
@@ -605,7 +609,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
605609
"nn_module_stack": "module_hierarchy_cos",
606610
},
607611
)
608-
mapping[333] = node_cos
612+
mapping[333] = [
613+
node_cos,
614+
]
609615
node_div = OperatorNode(
610616
"div",
611617
[node_cos],
@@ -616,7 +622,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
616622
"nn_module_stack": "module_hierarchy_div",
617623
},
618624
)
619-
mapping[444] = node_div
625+
mapping[444] = [
626+
node_div,
627+
]
620628
node_output = ValueNode("output", [node_div])
621629
return (
622630
OperatorGraph(

exir/backend/test/qnn_backend_demo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def preprocess(
2424
) -> PreprocessResult:
2525
processed_bytes = "imqnncompiled"
2626
all_nodes_debug_handle = [
27-
node.meta["debug_handle"] for node in edge_program.graph.nodes
27+
node.meta["debug_handle"]
28+
for node in edge_program.graph.nodes
29+
if node.op not in ("placeholder", "output")
2830
]
2931
return PreprocessResult(
3032
processed_bytes=bytes(processed_bytes, encoding="utf8"),

exir/backend/test/test_backends.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def forward(self, x):
194194
program=program,
195195
delegate=program.execution_plan[0].delegates[0],
196196
expected_id=BackendWithCompilerDemo.__name__,
197-
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
197+
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
198198
)
199199

200200
# Check the delegate instruction
@@ -414,7 +414,7 @@ def forward(self, x):
414414
program=program,
415415
delegate=program.execution_plan[0].delegates[0],
416416
expected_id=BackendWithCompilerDemo.__name__,
417-
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
417+
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
418418
)
419419

420420
# Check the delegate instruction
@@ -1320,7 +1320,7 @@ def forward(self, x):
13201320
program=program,
13211321
delegate=program.execution_plan[0].delegates[0],
13221322
expected_id=BackendWithCompilerDemo.__name__,
1323-
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
1323+
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
13241324
)
13251325

13261326
# Check the delegate instruction

exir/backend/test/test_backends_lifted.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def forward(self, x):
227227
program=program,
228228
delegate=program.execution_plan[0].delegates[0],
229229
expected_id=BackendWithCompilerDemo.__name__,
230-
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
230+
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
231231
)
232232

233233
# Check the delegate instruction
@@ -437,7 +437,7 @@ def forward(self, x):
437437
program=program,
438438
delegate=program.execution_plan[0].delegates[0],
439439
expected_id=BackendWithCompilerDemo.__name__,
440-
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
440+
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
441441
)
442442

443443
# Check the delegate instruction

exir/backend/test/test_debug_handle_map.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,13 @@ def test_lowered_the_whole_model(self, unlift):
9797
all_debug_handles = list(lowered_model.meta["debug_handle_map"].values())[0]
9898
self.assertEqual(
9999
len(all_debug_handles),
100-
len(lowered_model.original_module.graph.nodes),
100+
len(
101+
[
102+
node
103+
for node in lowered_model.original_module.graph.nodes
104+
if node.op not in ("placeholder", "output")
105+
]
106+
),
101107
)
102108

103109
class ComposedModel(torch.nn.Module):
@@ -127,5 +133,11 @@ def forward(self, *args):
127133
)[0]
128134
self.assertEqual(
129135
len(all_debug_handles),
130-
len(lowered_node.original_module.graph.nodes),
136+
len(
137+
[
138+
node
139+
for node in lowered_node.original_module.graph.nodes
140+
if node.op not in ("placeholder", "output")
141+
]
142+
),
131143
)

0 commit comments

Comments
 (0)