Commit 3a66602

[et] generate debug handle before operator decomposition

This diff updates debug handle generation: instead of each node in the edge program getting its own debug handle, all nodes that share the same ancestor in the export graph now share a single debug handle. This moves the starting point for tracing our node transformations from the edge graph back to the exported graph.

Differential Revision: [D76860368](https://our.internmc.facebook.com/intern/diff/D76860368/)

[ghstack-poisoned]

1 parent 85cf6ce commit 3a66602
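In rough terms, the new scheme works like the following standalone sketch (the node names and the provenance map are hypothetical, for illustration only, not the actual pass): edge-graph nodes that decompose from the same exported-graph op end up sharing one debug handle, while unrelated ops still get distinct handles.

# Hypothetical provenance map: edge-graph node -> its greatest ancestor in the export graph.
edge_node_to_export_ancestor = {
    "add_decomposed_1": "add",  # both pieces of the decomposed add...
    "add_decomposed_2": "add",  # ...trace back to the same exported op
    "relu": "relu",
}

source_to_handle = {}
edge_node_to_handle = {}
for edge_node, ancestor in edge_node_to_export_ancestor.items():
    # Mint a new handle the first time an ancestor is seen; reuse it afterwards.
    handle = source_to_handle.setdefault(ancestor, len(source_to_handle) + 1)
    edge_node_to_handle[edge_node] = handle

print(edge_node_to_handle)  # {'add_decomposed_1': 1, 'add_decomposed_2': 1, 'relu': 2}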

File tree

7 files changed: +121 −56 lines changed

devtools/inspector/_inspector.py

Lines changed: 16 additions & 9 deletions

@@ -654,7 +654,7 @@ def _populate_debugging_related_fields(

     def _associate_with_op_graph_nodes(
         self,
-        debug_handle_to_op_node_map: Dict[int, OperatorNode],
+        debug_handle_to_op_node_map: Dict[int, List[OperatorNode]],
     ) -> None:
         """
         Helper function to populate the stack_traces, module_hierarchy and op_types attributes
@@ -672,14 +672,21 @@ def _associate_with_op_graph_nodes(
             debug_handles = [debug_handles]

         for handle in debug_handles:
-            node = debug_handle_to_op_node_map.get(handle)
-            # Attach node metadata including stack traces, module hierarchy and op_types to this event
-            if node is not None and (metadata := node.metadata) is not None:
-                self.stack_traces[node.name] = metadata.get("stack_trace")
-                self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
-                if node.op:
-                    # TODO: consider having this as a dict from node.name -> node.op
-                    self.op_types += [node.op]
+            nodes = debug_handle_to_op_node_map.get(handle, None)
+            if nodes is None:
+                continue
+
+            for node in nodes:
+                # Attach node metadata including stack traces, module hierarchy and op_types to this event
+                if node is not None and (metadata := node.metadata) is not None:
+                    if node.name not in self.stack_traces:
+                        self.stack_traces[node.name] = metadata.get("stack_trace")
+                        self.module_hierarchy[node.name] = metadata.get(
+                            "nn_module_stack"
+                        )
+                    if node.op:
+                        # TODO: consider having this as a dict from node.name -> node.op
+                        self.op_types += [node.op]


 @dataclass
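Because the mapping is now one-to-many, each debug handle can fan out to several op nodes. Below is a minimal self-contained sketch of the consuming pattern above, using a hypothetical stand-in for OperatorNode (the real class lives in the devtools graph module):

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class FakeOpNode:  # hypothetical stand-in for OperatorNode, for illustration only
    name: str
    metadata: Optional[dict] = None
    op: Optional[str] = None

handle_map: Dict[int, List[FakeOpNode]] = {
    1: [
        FakeOpNode("conv", {"stack_trace": "model.py:10"}, "aten.convolution"),
        FakeOpNode("relu", {"stack_trace": "model.py:11"}, "aten.relu"),
    ],
}

stack_traces: Dict[str, str] = {}
for handle in (1, 99):  # 99 is absent, exercising the `nodes is None` branch
    nodes = handle_map.get(handle, None)
    if nodes is None:
        continue
    for node in nodes:
        # Mirror the event-population loop: record each node's trace at most once.
        if node.metadata is not None and node.name not in stack_traces:
            stack_traces[node.name] = node.metadata.get("stack_trace")

print(stack_traces)  # {'conv': 'model.py:10', 'relu': 'model.py:11'}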

devtools/inspector/_inspector_utils.py

Lines changed: 13 additions & 10 deletions

@@ -279,14 +279,18 @@ def gen_graphs_from_etrecord(
     return op_graph_map


+# One debug handle should normally be associated with exactly one node. We are in the middle of migrating debug handle
+# generation from the graph after to_edge to the graph after torch.export, so for now a debug handle in the exported
+# graph may be associated with multiple nodes in the to_edge graph. Once the migration is complete, we should bring the
+# original return type as well as the #node check back.
+# TODO(gasoonjia): recover the return type to Dict[int, OperatorNode], reenable the #node check.
 def create_debug_handle_to_op_node_mapping(
     op_graph: OperatorGraph,
-) -> Dict[int, OperatorNode]:
+) -> Dict[int, List[OperatorNode]]:
     """
     Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
     from each debug handle to the operator node that contains the debug handle in its metadata.
     """
-    debug_handle_to_op_node_map: Dict[int, OperatorNode] = {}
+    debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {}

     # Recursively searches through the metadata of nodes
     def _extract_debug_handles(graph: OperatorGraph):
@@ -296,14 +300,13 @@ def _extract_debug_handles(graph: OperatorGraph):
             if isinstance(element, OperatorNode) and element.metadata is not None:
                 metadata = element.metadata
                 debug_handle = metadata.get("debug_handle")
-                if debug_handle is not None:
-                    existing_entry = debug_handle_to_op_node_map.get(debug_handle)
-                    if existing_entry is not None:
-                        raise ValueError(
-                            f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
-                            "No two op nodes of the same graph should have the same debug handle."
-                        )
-                    debug_handle_to_op_node_map[debug_handle] = element
+                if debug_handle is None:
+                    continue
+
+                if debug_handle not in debug_handle_to_op_node_map:
+                    debug_handle_to_op_node_map[debug_handle] = []
+
+                debug_handle_to_op_node_map[debug_handle].append(element)

     # Start traversing
     _extract_debug_handles(op_graph)
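The accumulation pattern above (membership check, then append) behaves like a defaultdict; a short sketch with hypothetical handle/name pairs:

from collections import defaultdict
from typing import DefaultDict, List

debug_handle_to_nodes: DefaultDict[int, List[str]] = defaultdict(list)

# Hypothetical (handle, node-name) pairs in the order a traversal might visit them.
for handle, name in [(1, "conv"), (1, "relu"), (2, "sin")]:
    debug_handle_to_nodes[handle].append(name)  # duplicate handles accumulate instead of raising

print(dict(debug_handle_to_nodes))  # {1: ['conv', 'relu'], 2: ['sin']}

The diff keeps the explicit membership check rather than a defaultdict, presumably so the returned value stays a plain Dict for callers.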

devtools/inspector/tests/inspector_test.py

Lines changed: 5 additions & 1 deletion

@@ -182,7 +182,11 @@ def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self):

         # Call the method that's under testing and verify
         event_with_single_debug_handle._associate_with_op_graph_nodes(
-            {debug_handle: node_0}
+            {
+                debug_handle: [
+                    node_0,
+                ]
+            }
         )

         expected_stack_traces = {"node_0": "stack_trace_relu"}

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 11 additions & 19 deletions

@@ -62,25 +62,17 @@ def get_expected_intermediate_outputs():
     Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
     """
     return {
-        (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
-        (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
-        (12,): torch.tensor(
-            [
-                [0.1000, 0.5000],
-                [0.2000, 0.6000],
-                [0.3000, 0.7000],
-                [0.4000, 0.8000],
-            ]
-        ),
-        (13,): torch.tensor([[5.0000, 14.1200]]),
-        (14,): torch.tensor([[5.5000, 13.6200]]),
-        (15,): torch.tensor([[5.4000, 13.5200]]),
-        (16,): torch.tensor([[10.8000, 6.7600]]),
-        (17,): torch.tensor([3.0000, 1.5000]),
-        (18,): torch.tensor([[3.6000, 4.5067]]),
-        (19,): torch.tensor([[3.6000, 4.5067]]),
-        (20,): torch.tensor([[0.9734, 0.9891]]),
-        (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
+        (1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
+        (2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
+        (3,): torch.tensor([[5.0000, 14.1200]]),
+        (4,): torch.tensor([[5.5000, 13.6200]]),
+        (5,): torch.tensor([[5.4000, 13.5200]]),
+        (6,): torch.tensor([[10.8000, 6.7600]]),
+        (7,): torch.tensor([3.0000, 1.5000]),
+        (8,): torch.tensor([[3.6000, 4.5067]]),
+        (9,): torch.tensor([[3.6000, 4.5067]]),
+        (10,): torch.tensor([[0.9734, 0.9891]]),
+        (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
     }

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 12 additions & 4 deletions

@@ -381,7 +381,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
             "nn_module_stack": "module_hierarchy_relu",
         },
     )
-    mapping[111] = node_fused_conv_relu
+    mapping[111] = [
+        node_fused_conv_relu,
+    ]
     node_sin = OperatorNode(
         "sin",
         [node_fused_conv_relu],
@@ -392,7 +394,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
             "nn_module_stack": "module_hierarchy_sin",
         },
     )
-    mapping[222] = node_sin
+    mapping[222] = [
+        node_sin,
+    ]
     node_cos = OperatorNode(
         "cos",
         [node_sin],
@@ -403,7 +407,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
             "nn_module_stack": "module_hierarchy_cos",
         },
     )
-    mapping[333] = node_cos
+    mapping[333] = [
+        node_cos,
+    ]
     node_div = OperatorNode(
         "div",
         [node_cos],
@@ -414,7 +420,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
             "nn_module_stack": "module_hierarchy_div",
         },
     )
-    mapping[444] = node_div
+    mapping[444] = [
+        node_div,
+    ]
     node_output = ValueNode("output", [node_div])
     return (
         OperatorGraph(

exir/passes/debug_handle_generator_pass.py

Lines changed: 47 additions & 6 deletions

@@ -4,10 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Dict
+
 from executorch.exir.graph_module import bfs_trace_with_node_process
 from executorch.exir.pass_base import ExportPass
 from torch.export import ExportedProgram
-from torch.fx import GraphModule
+from torch.fx import GraphModule, Node
 from torch.fx.passes.infra.pass_base import PassResult

@@ -17,18 +19,57 @@ def call(self, graph_module: GraphModule) -> PassResult:
         to executorch backend, that has a canonical set of quantized operators
         """

-        index = 1
+        FROM_NODE_KEY = "from_node"
+        DEBUG_HANDLE_KEY = "debug_handle"
+
+        source_node_to_debug_handle: Dict[str, int] = {}
+
+        def _get_greatest_ancestor_source_node(node: Node) -> str:
+            """Get the source of the greatest ancestor node of the given node. The source
+            here means the name of the node concatenated with the id of the graph it belongs to.
+            For example, if the node transformation is node a -> b -> c, then the greatest
+            ancestor node of c is a.
+            """
+
+            node_source = node.meta[FROM_NODE_KEY]
+            node_source = node_source[-1]
+
+            while len(node_source.from_node) > 0:
+                node_source = node_source.from_node[-1]
+
+            return node_source.name + str(node_source.graph_id)
+
+        def _extract_debug_handles_from_node(node: Node) -> None:
+            """
+            Generate a debug handle based on the node's oldest ancestor node's name
+            and graph id, or return None if the node does not need to be traced.
+            """
+
+            if node.op == "placeholder" or node.op == "output":
+                # placeholder and output nodes don't have debug handles
+                return
+
+            assert (
+                FROM_NODE_KEY in node.meta
+            ), f"Node {node} does not have meta key {FROM_NODE_KEY}"
+
+            source_node = _get_greatest_ancestor_source_node(node)
+
+            debug_handle = (
+                len(source_node_to_debug_handle) + 1
+                if source_node not in source_node_to_debug_handle
+                else source_node_to_debug_handle[source_node]
+            )
+            source_node_to_debug_handle[source_node] = debug_handle

-        def _extract_debug_handles_from_node(node):
-            nonlocal index
-            node.meta["debug_handle"] = index
-            index += 1
+            node.meta[DEBUG_HANDLE_KEY] = debug_handle

         bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node)

         return PassResult(graph_module, True)


+# TODO(gasoonjia): generate missing debug handles using `from_node` info
 def generate_missing_debug_handles(ep: ExportedProgram):
     """
     This pass is used to generate missing debug handles for the graph module and its submodules.
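The core of the new pass is the walk over `from_node` provenance in _get_greatest_ancestor_source_node. Below is a self-contained sketch of that walk, using a fake provenance record instead of the real torch.fx metadata (the name, graph_id, and from_node fields mirror the diff; everything else here is hypothetical):

from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeSource:  # hypothetical stand-in for a from_node provenance entry
    name: str
    graph_id: int
    from_node: List["FakeSource"] = field(default_factory=list)

def greatest_ancestor_key(from_node_meta: List[FakeSource]) -> str:
    """Follow the last provenance entry until the chain ends, as the pass does."""
    src = from_node_meta[-1]
    while len(src.from_node) > 0:
        src = src.from_node[-1]
    return src.name + str(src.graph_id)

# Transformation chain a -> b -> c: c's provenance bottoms out at a.
a = FakeSource("a", 100)
b = FakeSource("b", 200, from_node=[a])
c_meta = [FakeSource("c", 300, from_node=[b])]
print(greatest_ancestor_key(c_meta))  # prints "a100"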

exir/tests/test_passes.py

Lines changed: 17 additions & 7 deletions

@@ -859,22 +859,32 @@ def test_debug_handle_generator_pass(self) -> None:
             .exported_program()
             .graph_module
         )
+
+        # Every node except input and output should have a debug handle
         for node in graph_module.graph.nodes:
-            self.assertIn("debug_handle", node.meta)
+            if node.op != "placeholder" and node.op != "output":
+                self.assertIn("debug_handle", node.meta)
         ScalarToTensorPass()(graph_module)
+
         for node in graph_module.graph.nodes:
-            self.assertIn("debug_handle", node.meta)
+            if node.op != "placeholder" and node.op != "output":
+                self.assertIn("debug_handle", node.meta)

     def test_generate_missing_debug_handles(self) -> None:
         eager_model = MLP(2, output_size=4)
         inputs = eager_model.get_random_inputs()

         ep = to_edge(export(eager_model, inputs, strict=True)).exported_program()

-        list(ep.graph.nodes)[0].meta.pop("debug_handle")
-        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
+        # get the first non-placeholder node
+        first_non_placeholder_node = [
+            n for n in ep.graph.nodes if n.op != "placeholder"
+        ][0]
+
+        first_non_placeholder_node.meta.pop("debug_handle")
+        self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is None)
         generate_missing_debug_handles(ep)
-        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)
+        self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is not None)

     def test_debug_handle_generator_pass_with_control_flow(self) -> None:
         def true_nested(y: torch.Tensor) -> torch.Tensor:
@@ -928,7 +938,8 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
             while queue:
                 current_graph_module = queue.pop(0)
                 for node in current_graph_module.graph.nodes:
-                    self.assertIn("debug_handle", node.meta)
+                    if node.op != "placeholder" and node.op != "output":
+                        self.assertIn("debug_handle", node.meta)
                 control_flow_submodules = [
                     submodule
                     for _, submodule, _ in get_control_flow_submodules(
@@ -939,7 +950,6 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:

         DebugHandleGeneratorPass()(graph_module)
         check_debug_handle_metadata(graph_module)
-        generate_missing_debug_handles(ep)

         # Check debug handle still preserved after ScalarToTensorPass
         ScalarToTensorPass()(graph_module)
