Skip to content

Commit 4f612ff

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
generate debug handle before opeartor decomposition (#11949)
Summary: This diff update the debug handle generation, from each node in the edge program having a individual debug handle, to all nodes having a same ancestor in export graph sharing a same debug handle, which update the start point of tracing our node transformation from edge graph to exported graph. Differential Revision: D76860368
1 parent 42195da commit 4f612ff

File tree

4 files changed

+76
-26
lines changed

4 files changed

+76
-26
lines changed

devtools/inspector/_inspector.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def _populate_debugging_related_fields(
654654

655655
def _associate_with_op_graph_nodes(
656656
self,
657-
debug_handle_to_op_node_map: Dict[int, OperatorNode],
657+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]],
658658
) -> None:
659659
"""
660660
Helper function to populate the stack_traces, module_hierarchy and op_types attributes
@@ -672,14 +672,21 @@ def _associate_with_op_graph_nodes(
672672
debug_handles = [debug_handles]
673673

674674
for handle in debug_handles:
675-
node = debug_handle_to_op_node_map.get(handle)
676-
# Attach node metadata including stack traces, module hierarchy and op_types to this event
677-
if node is not None and (metadata := node.metadata) is not None:
678-
self.stack_traces[node.name] = metadata.get("stack_trace")
679-
self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
680-
if node.op:
681-
# TODO: consider having this as a dict from node.name -> node.op
682-
self.op_types += [node.op]
675+
nodes = debug_handle_to_op_node_map.get(handle, None)
676+
if nodes is None:
677+
continue
678+
679+
for node in nodes:
680+
# Attach node metadata including stack traces, module hierarchy and op_types to this event
681+
if node is not None and (metadata := node.metadata) is not None:
682+
if node.name not in self.stack_traces:
683+
self.stack_traces[node.name] = metadata.get("stack_trace")
684+
self.module_hierarchy[node.name] = metadata.get(
685+
"nn_module_stack"
686+
)
687+
if node.op:
688+
# TODO: consider having this as a dict from node.name -> node.op
689+
self.op_types += [node.op]
683690

684691

685692
@dataclass

devtools/inspector/_inspector_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,18 @@ def gen_graphs_from_etrecord(
279279
return op_graph_map
280280

281281

282+
# One debug handle should only be associated with one node. We are in the middle of migrating debug handle generation
283+
# from graph after to_edge to graph after torch.export, one every debug handle in exported graph may be associated with multiple nodes in to_edge
284+
# graph. After fully migration, we should bring the bring type as well as the #node check back.
285+
# TODO(gasoonjia): recover the return type to Dict[int, List[OperatorNode], reenable the #node check.
282286
def create_debug_handle_to_op_node_mapping(
283287
op_graph: OperatorGraph,
284-
) -> Dict[int, OperatorNode]:
288+
) -> Dict[int, List[OperatorNode]]:
285289
"""
286290
Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
287291
from each debug handle to the operator node that contains the debug handle in its metadata.
288292
"""
289-
debug_handle_to_op_node_map: Dict[int, OperatorNode] = {}
293+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {}
290294

291295
# Recursively searches through the metadata of nodes
292296
def _extract_debug_handles(graph: OperatorGraph):
@@ -296,14 +300,13 @@ def _extract_debug_handles(graph: OperatorGraph):
296300
if isinstance(element, OperatorNode) and element.metadata is not None:
297301
metadata = element.metadata
298302
debug_handle = metadata.get("debug_handle")
299-
if debug_handle is not None:
300-
existing_entry = debug_handle_to_op_node_map.get(debug_handle)
301-
if existing_entry is not None:
302-
raise ValueError(
303-
f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
304-
"No two op nodes of the same graph should have the same debug handle."
305-
)
306-
debug_handle_to_op_node_map[debug_handle] = element
303+
if debug_handle is None:
304+
continue
305+
306+
if debug_handle not in debug_handle_to_op_node_map:
307+
debug_handle_to_op_node_map[debug_handle] = []
308+
309+
debug_handle_to_op_node_map[debug_handle].append(element)
307310

308311
# Start traversing
309312
_extract_debug_handles(op_graph)

exir/passes/debug_handle_generator_pass.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Dict
8+
79
from executorch.exir.graph_module import bfs_trace_with_node_process
810
from executorch.exir.pass_base import ExportPass
911
from torch.export import ExportedProgram
10-
from torch.fx import GraphModule
12+
from torch.fx import GraphModule, Node
1113
from torch.fx.passes.infra.pass_base import PassResult
1214

1315

@@ -17,12 +19,50 @@ def call(self, graph_module: GraphModule) -> PassResult:
1719
to executorch backend, that has a canonical set of quantized operators
1820
"""
1921

20-
index = 1
22+
FROM_NODE_KEY = "from_node"
23+
DEBUG_HANDLE_KEY = "debug_handle"
24+
25+
source_node_to_debug_handle: Dict[str, int] = {}
26+
27+
def _get_greatest_ancestor_source_node(node: Node) -> str:
28+
""" Get the source of the greatest ancestor node of the given node. The source
29+
here means the name of the node concated with the id the graph it belongs to.
30+
For example, if the node transformation is node a -> b -> c, then the greatest
31+
ancestor node of c is a.
32+
"""
33+
34+
node_source = node.meta[FROM_NODE_KEY]
35+
node_source = node_source[-1]
36+
37+
while len(node_source.from_node) > 0:
38+
node_source = node_source.from_node[-1]
39+
40+
return node_source.name + str(node_source.graph_id)
41+
42+
def _extract_debug_handles_from_node(node: Node) -> None:
43+
"""
44+
Generate a debug handle based on node's oldest ancestor node's name
45+
and graph id, or return None if the node does not need to be traced.
46+
"""
47+
48+
if node.op == "placeholder" or node.op == "output":
49+
# placeholder and output nodes don't have debug handle
50+
return
51+
52+
assert (
53+
FROM_NODE_KEY in node.meta
54+
), f"Node {node} does not have meta key {FROM_NODE_KEY}"
55+
56+
source_node = _get_greatest_ancestor_source_node(node)
57+
58+
debug_handle = (
59+
len(source_node_to_debug_handle) + 1
60+
if source_node not in source_node_to_debug_handle
61+
else source_node_to_debug_handle[source_node]
62+
)
63+
source_node_to_debug_handle[source_node] = debug_handle
2164

22-
def _extract_debug_handles_from_node(node):
23-
nonlocal index
24-
node.meta["debug_handle"] = index
25-
index += 1
65+
node.meta[DEBUG_HANDLE_KEY] = debug_handle
2666

2767
bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node)
2868

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
939939

940940
DebugHandleGeneratorPass()(graph_module)
941941
check_debug_handle_metadata(graph_module)
942-
generate_missing_debug_handles(ep)
942+
# generate_missing_debug_handles(ep)
943943

944944
# Check debug handle still preserved after ScalarToTensorPass
945945
ScalarToTensorPass()(graph_module)

0 commit comments

Comments
 (0)