Skip to content

Commit fe703c7

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
generate debug handle before opeartor decomposition (#11949)
Summary: Rollback Plan: Differential Revision: D76860368
1 parent 42195da commit fe703c7

File tree

4 files changed

+74
-26
lines changed

4 files changed

+74
-26
lines changed

devtools/inspector/_inspector.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def _populate_debugging_related_fields(
654654

655655
def _associate_with_op_graph_nodes(
656656
self,
657-
debug_handle_to_op_node_map: Dict[int, OperatorNode],
657+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]],
658658
) -> None:
659659
"""
660660
Helper function to populate the stack_traces, module_hierarchy and op_types attributes
@@ -672,14 +672,21 @@ def _associate_with_op_graph_nodes(
672672
debug_handles = [debug_handles]
673673

674674
for handle in debug_handles:
675-
node = debug_handle_to_op_node_map.get(handle)
676-
# Attach node metadata including stack traces, module hierarchy and op_types to this event
677-
if node is not None and (metadata := node.metadata) is not None:
678-
self.stack_traces[node.name] = metadata.get("stack_trace")
679-
self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
680-
if node.op:
681-
# TODO: consider having this as a dict from node.name -> node.op
682-
self.op_types += [node.op]
675+
nodes = debug_handle_to_op_node_map.get(handle, None)
676+
if nodes is None:
677+
continue
678+
679+
for node in nodes:
680+
# Attach node metadata including stack traces, module hierarchy and op_types to this event
681+
if node is not None and (metadata := node.metadata) is not None:
682+
if self.asdict().get("stack_traces") is None:
683+
self.stack_traces[node.name] = metadata.get("stack_trace")
684+
self.module_hierarchy[node.name] = metadata.get(
685+
"nn_module_stack"
686+
)
687+
if node.op:
688+
# TODO: consider having this as a dict from node.name -> node.op
689+
self.op_types += [node.op]
683690

684691

685692
@dataclass

devtools/inspector/_inspector_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,18 @@ def gen_graphs_from_etrecord(
279279
return op_graph_map
280280

281281

282+
# One debug handle should only be associated with one node. We are in the middle of migrating debug handle generation
283+
# from graph after to_edge to graph after torch.export, one every debug handle in exported graph may be associated with multiple nodes in to_edge
284+
# graph. After fully migration, we should bring the bring type as well as the #node check back.
285+
# TODO(gasoonjia): recover the return type to Dict[int, List[OperatorNode], reenable the #node check.
282286
def create_debug_handle_to_op_node_mapping(
283287
op_graph: OperatorGraph,
284-
) -> Dict[int, OperatorNode]:
288+
) -> Dict[int, List[OperatorNode]]:
285289
"""
286290
Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
287291
from each debug handle to the operator node that contains the debug handle in its metadata.
288292
"""
289-
debug_handle_to_op_node_map: Dict[int, OperatorNode] = {}
293+
debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {}
290294

291295
# Recursively searches through the metadata of nodes
292296
def _extract_debug_handles(graph: OperatorGraph):
@@ -296,14 +300,13 @@ def _extract_debug_handles(graph: OperatorGraph):
296300
if isinstance(element, OperatorNode) and element.metadata is not None:
297301
metadata = element.metadata
298302
debug_handle = metadata.get("debug_handle")
299-
if debug_handle is not None:
300-
existing_entry = debug_handle_to_op_node_map.get(debug_handle)
301-
if existing_entry is not None:
302-
raise ValueError(
303-
f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
304-
"No two op nodes of the same graph should have the same debug handle."
305-
)
306-
debug_handle_to_op_node_map[debug_handle] = element
303+
if debug_handle is None:
304+
continue
305+
306+
if debug_handle not in debug_handle_to_op_node_map:
307+
debug_handle_to_op_node_map[debug_handle] = []
308+
309+
debug_handle_to_op_node_map[debug_handle].append(element)
307310

308311
# Start traversing
309312
_extract_debug_handles(op_graph)

exir/passes/debug_handle_generator_pass.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Dict
8+
79
from executorch.exir.graph_module import bfs_trace_with_node_process
810
from executorch.exir.pass_base import ExportPass
911
from torch.export import ExportedProgram
10-
from torch.fx import GraphModule
12+
from torch.fx import GraphModule, Node
1113
from torch.fx.passes.infra.pass_base import PassResult
1214

1315

@@ -17,12 +19,48 @@ def call(self, graph_module: GraphModule) -> PassResult:
1719
to executorch backend, that has a canonical set of quantized operators
1820
"""
1921

20-
index = 1
22+
FROM_NODE_KEY = "from_node"
23+
DEBUG_HANDLE_KEY = "debug_handle"
24+
25+
source_node_to_debug_handle: Dict[str, int] = {}
26+
27+
def _get_greatest_ancestor_node_source(node: Node):
28+
node_source = node.meta[FROM_NODE_KEY]
29+
node_source = node_source[-1]
30+
31+
while len(node_source.from_node) > 0:
32+
node_source = node_source.from_node[-1]
33+
34+
return node_source
35+
36+
def _extract_debug_handles_from_node(node: Node) -> None:
37+
"""
38+
Generate a debug handle based on node's oldest ancestor node's name
39+
and graph id, or return None if the node does not need to be traced.
40+
"""
41+
42+
if node.op == "placeholder" or node.op == "output":
43+
# placeholder and output nodes don't have debug handle
44+
return
45+
46+
assert (
47+
FROM_NODE_KEY in node.meta
48+
), f"Node {node} does not have meta key {FROM_NODE_KEY}"
49+
50+
greatest_ancestor_node_source = _get_greatest_ancestor_node_source(node)
51+
52+
source_node = greatest_ancestor_node_source.name + str(
53+
greatest_ancestor_node_source.graph_id
54+
)
55+
56+
debug_handle = (
57+
len(source_node_to_debug_handle) + 1
58+
if source_node not in source_node_to_debug_handle
59+
else source_node_to_debug_handle[source_node]
60+
)
61+
source_node_to_debug_handle[source_node] = debug_handle
2162

22-
def _extract_debug_handles_from_node(node):
23-
nonlocal index
24-
node.meta["debug_handle"] = index
25-
index += 1
63+
node.meta[DEBUG_HANDLE_KEY] = debug_handle
2664

2765
bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node)
2866

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
939939

940940
DebugHandleGeneratorPass()(graph_module)
941941
check_debug_handle_metadata(graph_module)
942-
generate_missing_debug_handles(ep)
942+
# generate_missing_debug_handles(ep)
943943

944944
# Check debug handle still preserved after ScalarToTensorPass
945945
ScalarToTensorPass()(graph_module)

0 commit comments

Comments
 (0)