Skip to content

Commit fdef350

Browse files
Gasoonjiafacebook-github-bot
authored and committed
generate debug handle before operator decomposition (#11949)
Summary: This diff updates the debug handle generation: instead of each node in the edge program having an individual debug handle, all nodes that share the same ancestor in the export graph now share the same debug handle. This moves the starting point for tracing our node transformations from the edge graph back to the exported graph. Differential Revision: D76860368
1 parent 0c19a0b commit fdef350

File tree

7 files changed

+126
-56
lines changed

7 files changed

+126
-56
lines changed

devtools/inspector/_inspector.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def _populate_debugging_related_fields(
655655

656656
def _associate_with_op_graph_nodes(
657657
self,
658-
debug_handle_to_op_node_map: Dict[int, OperatorNode],
658+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]],
659659
) -> None:
660660
"""
661661
Helper function to populate the stack_traces, module_hierarchy and op_types attributes
@@ -673,14 +673,21 @@ def _associate_with_op_graph_nodes(
673673
debug_handles = [debug_handles]
674674

675675
for handle in debug_handles:
676-
node = debug_handle_to_op_node_map.get(handle)
677-
# Attach node metadata including stack traces, module hierarchy and op_types to this event
678-
if node is not None and (metadata := node.metadata) is not None:
679-
self.stack_traces[node.name] = metadata.get("stack_trace")
680-
self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
681-
if node.op:
682-
# TODO: consider having this as a dict from node.name -> node.op
683-
self.op_types += [node.op]
676+
nodes = debug_handle_to_op_node_map.get(handle, None)
677+
if nodes is None:
678+
continue
679+
680+
for node in nodes:
681+
# Attach node metadata including stack traces, module hierarchy and op_types to this event
682+
if node is not None and (metadata := node.metadata) is not None:
683+
if node.name not in self.stack_traces:
684+
self.stack_traces[node.name] = metadata.get("stack_trace")
685+
self.module_hierarchy[node.name] = metadata.get(
686+
"nn_module_stack"
687+
)
688+
if node.op:
689+
# TODO: consider having this as a dict from node.name -> node.op
690+
self.op_types += [node.op]
684691

685692

686693
@dataclass

devtools/inspector/_inspector_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,18 @@ def gen_graphs_from_etrecord(
301301
return op_graph_map
302302

303303

304+
# One debug handle should only be associated with one node. We are in the middle of migrating debug handle generation
305+
# from the graph after to_edge to the graph after torch.export, so every debug handle in the exported graph may be associated with multiple nodes in the to_edge
306+
# graph. After the migration is complete, we should bring the original return type as well as the #node check back.
307+
# TODO(gasoonjia): recover the return type to Dict[int, OperatorNode], reenable the #node check.
304308
def create_debug_handle_to_op_node_mapping(
305309
op_graph: OperatorGraph,
306-
) -> Dict[int, OperatorNode]:
310+
) -> Dict[int, List[OperatorNode]]:
307311
"""
308312
Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
309313
from each debug handle to the operator node that contains the debug handle in its metadata.
310314
"""
311-
debug_handle_to_op_node_map: Dict[int, OperatorNode] = {}
315+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {}
312316

313317
# Recursively searches through the metadata of nodes
314318
def _extract_debug_handles(graph: OperatorGraph):
@@ -318,14 +322,13 @@ def _extract_debug_handles(graph: OperatorGraph):
318322
if isinstance(element, OperatorNode) and element.metadata is not None:
319323
metadata = element.metadata
320324
debug_handle = metadata.get("debug_handle")
321-
if debug_handle is not None:
322-
existing_entry = debug_handle_to_op_node_map.get(debug_handle)
323-
if existing_entry is not None:
324-
raise ValueError(
325-
f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
326-
"No two op nodes of the same graph should have the same debug handle."
327-
)
328-
debug_handle_to_op_node_map[debug_handle] = element
325+
if debug_handle is None:
326+
continue
327+
328+
if debug_handle not in debug_handle_to_op_node_map:
329+
debug_handle_to_op_node_map[debug_handle] = []
330+
331+
debug_handle_to_op_node_map[debug_handle].append(element)
329332

330333
# Start traversing
331334
_extract_debug_handles(op_graph)

devtools/inspector/tests/inspector_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,11 @@ def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self):
182182

183183
# Call the method that's under testing and verify
184184
event_with_single_debug_handle._associate_with_op_graph_nodes(
185-
{debug_handle: node_0}
185+
{
186+
debug_handle: [
187+
node_0,
188+
]
189+
}
186190
)
187191

188192
expected_stack_traces = {"node_0": "stack_trace_relu"}

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,17 @@ def get_expected_intermediate_outputs():
6262
Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
6363
"""
6464
return {
65-
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
66-
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
67-
(12,): torch.tensor(
68-
[
69-
[0.1000, 0.5000],
70-
[0.2000, 0.6000],
71-
[0.3000, 0.7000],
72-
[0.4000, 0.8000],
73-
]
74-
),
75-
(13,): torch.tensor([[5.0000, 14.1200]]),
76-
(14,): torch.tensor([[5.5000, 13.6200]]),
77-
(15,): torch.tensor([[5.4000, 13.5200]]),
78-
(16,): torch.tensor([[10.8000, 6.7600]]),
79-
(17,): torch.tensor([3.0000, 1.5000]),
80-
(18,): torch.tensor([[3.6000, 4.5067]]),
81-
(19,): torch.tensor([[3.6000, 4.5067]]),
82-
(20,): torch.tensor([[0.9734, 0.9891]]),
83-
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
65+
(1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
66+
(2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
67+
(3,): torch.tensor([[5.0000, 14.1200]]),
68+
(4,): torch.tensor([[5.5000, 13.6200]]),
69+
(5,): torch.tensor([[5.4000, 13.5200]]),
70+
(6,): torch.tensor([[10.8000, 6.7600]]),
71+
(7,): torch.tensor([3.0000, 1.5000]),
72+
(8,): torch.tensor([[3.6000, 4.5067]]),
73+
(9,): torch.tensor([[3.6000, 4.5067]]),
74+
(10,): torch.tensor([[0.9734, 0.9891]]),
75+
(11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
8476
}
8577

8678

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ def test_create_debug_handle_to_op_node_mapping(self):
7272
graph, expected_mapping = gen_mock_operator_graph_with_expected_map()
7373
debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping(graph)
7474

75+
print(debug_handle_to_op_node_map[111])
76+
print("+" * 100)
77+
print(expected_mapping[111])
78+
print("+" * 100)
7579
self.assertEqual(debug_handle_to_op_node_map, expected_mapping)
7680

7781
def test_find_populated_event(self):
@@ -489,7 +493,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
489493
"nn_module_stack": "module_hierarchy_relu",
490494
},
491495
)
492-
mapping[111] = node_fused_conv_relu
496+
mapping[111] = [
497+
node_fused_conv_relu,
498+
]
493499
node_sin = OperatorNode(
494500
"sin",
495501
[node_fused_conv_relu],
@@ -500,7 +506,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
500506
"nn_module_stack": "module_hierarchy_sin",
501507
},
502508
)
503-
mapping[222] = node_sin
509+
mapping[222] = [
510+
node_sin,
511+
]
504512
node_cos = OperatorNode(
505513
"cos",
506514
[node_sin],
@@ -511,7 +519,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
511519
"nn_module_stack": "module_hierarchy_cos",
512520
},
513521
)
514-
mapping[333] = node_cos
522+
mapping[333] = [
523+
node_cos,
524+
]
515525
node_div = OperatorNode(
516526
"div",
517527
[node_cos],
@@ -522,7 +532,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
522532
"nn_module_stack": "module_hierarchy_div",
523533
},
524534
)
525-
mapping[444] = node_div
535+
mapping[444] = [
536+
node_div,
537+
]
526538
node_output = ValueNode("output", [node_div])
527539
return (
528540
OperatorGraph(

exir/passes/debug_handle_generator_pass.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Dict
8+
79
from executorch.exir.graph_module import bfs_trace_with_node_process
810
from executorch.exir.pass_base import ExportPass
911
from torch.export import ExportedProgram
10-
from torch.fx import GraphModule
12+
from torch.fx import GraphModule, Node
1113
from torch.fx.passes.infra.pass_base import PassResult
1214

1315

@@ -17,18 +19,57 @@ def call(self, graph_module: GraphModule) -> PassResult:
1719
to executorch backend, that has a canonical set of quantized operators
1820
"""
1921

20-
index = 1
22+
FROM_NODE_KEY = "from_node"
23+
DEBUG_HANDLE_KEY = "debug_handle"
24+
25+
source_node_to_debug_handle: Dict[str, int] = {}
26+
27+
def _get_greatest_ancestor_source_node(node: Node) -> str:
28+
"""Get the source of the greatest ancestor node of the given node. The source
29+
here means the name of the node concatenated with the id of the graph it belongs to.
30+
For example, if the node transformation is node a -> b -> c, then the greatest
31+
ancestor node of c is a.
32+
"""
33+
34+
node_source = node.meta[FROM_NODE_KEY]
35+
node_source = node_source[-1]
36+
37+
while len(node_source.from_node) > 0:
38+
node_source = node_source.from_node[-1]
39+
40+
return node_source.name + str(node_source.graph_id)
41+
42+
def _extract_debug_handles_from_node(node: Node) -> None:
43+
"""
44+
Generate a debug handle based on node's oldest ancestor node's name
45+
and graph id, or return None if the node does not need to be traced.
46+
"""
47+
48+
if node.op == "placeholder" or node.op == "output":
49+
# placeholder and output nodes don't have debug handle
50+
return
51+
52+
assert (
53+
FROM_NODE_KEY in node.meta
54+
), f"Node {node} does not have meta key {FROM_NODE_KEY}"
55+
56+
source_node = _get_greatest_ancestor_source_node(node)
57+
58+
debug_handle = (
59+
len(source_node_to_debug_handle) + 1
60+
if source_node not in source_node_to_debug_handle
61+
else source_node_to_debug_handle[source_node]
62+
)
63+
source_node_to_debug_handle[source_node] = debug_handle
2164

22-
def _extract_debug_handles_from_node(node):
23-
nonlocal index
24-
node.meta["debug_handle"] = index
25-
index += 1
65+
node.meta[DEBUG_HANDLE_KEY] = debug_handle
2666

2767
bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node)
2868

2969
return PassResult(graph_module, True)
3070

3171

72+
# TODO(gasoonjia): generate missing debug handles using `from_node` info
3273
def generate_missing_debug_handles(ep: ExportedProgram):
3374
"""
3475
This pass is used to generate missing debug handles for the graph module and its submodules.

exir/tests/test_passes.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -859,22 +859,32 @@ def test_debug_handle_generator_pass(self) -> None:
859859
.exported_program()
860860
.graph_module
861861
)
862+
863+
# Every node except input and output should have debug handle
862864
for node in graph_module.graph.nodes:
863-
self.assertIn("debug_handle", node.meta)
865+
if node.op != "placeholder" and node.op != "output":
866+
self.assertIn("debug_handle", node.meta)
864867
ScalarToTensorPass()(graph_module)
868+
865869
for node in graph_module.graph.nodes:
866-
self.assertIn("debug_handle", node.meta)
870+
if node.op != "placeholder" and node.op != "output":
871+
self.assertIn("debug_handle", node.meta)
867872

868873
def test_generate_missing_debug_handles(self) -> None:
869874
eager_model = MLP(2, output_size=4)
870875
inputs = eager_model.get_random_inputs()
871876

872877
ep = to_edge(export(eager_model, inputs, strict=True)).exported_program()
873878

874-
list(ep.graph.nodes)[0].meta.pop("debug_handle")
875-
self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
879+
# get the first non-placeholder node
880+
first_non_placeholder_node = [
881+
n for n in ep.graph.nodes if n.op != "placeholder"
882+
][0]
883+
884+
first_non_placeholder_node.meta.pop("debug_handle")
885+
self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is None)
876886
generate_missing_debug_handles(ep)
877-
self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)
887+
self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is not None)
878888

879889
def test_debug_handle_generator_pass_with_control_flow(self) -> None:
880890
def true_nested(y: torch.Tensor) -> torch.Tensor:
@@ -928,7 +938,8 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
928938
while queue:
929939
current_graph_module = queue.pop(0)
930940
for node in current_graph_module.graph.nodes:
931-
self.assertIn("debug_handle", node.meta)
941+
if node.op != "placeholder" and node.op != "output":
942+
self.assertIn("debug_handle", node.meta)
932943
control_flow_submodules = [
933944
submodule
934945
for _, submodule, _ in get_control_flow_submodules(
@@ -939,7 +950,7 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
939950

940951
DebugHandleGeneratorPass()(graph_module)
941952
check_debug_handle_metadata(graph_module)
942-
generate_missing_debug_handles(ep)
953+
# generate_missing_debug_handles(ep)
943954

944955
# Check debug handle still preserved after ScalarToTensorPass
945956
ScalarToTensorPass()(graph_module)

0 commit comments

Comments
 (0)