
Commit 5774995

propagate debug handle from edge dialect graph back to exported graph
Our target is to use the exported graph from torch.export as the source of truth for AOT intermediate outputs. One thing blocking us is that the exported graph does not carry debug handles: they do not show up in the export flow until DebugHandleGeneratorPass, the last step of to_edge(). We need to equip the exported graph with the same debug handles used in the ExecuTorch flow.

This diff creates a propagate_back_debug_handle function, which propagates debug handles from the edge dialect program back to the exported program while maintaining the correctness of operator tracing.

Differential Revision: [D78051614](https://our.internmc.facebook.com/intern/diff/D78051614/)

[ghstack-poisoned]
1 parent c9d7bc8 commit 5774995

File tree

6 files changed: +230 −21 lines

devtools/inspector/_inspector_utils.py

Lines changed: 75 additions & 0 deletions

@@ -35,8 +35,17 @@
 from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc
 from executorch.devtools.etrecord import ETRecord
 
+from executorch.exir.debug_handle_utils import (
+    DEBUG_HANDLE_KEY,
+    get_greatest_ancestor_node_identifier,
+)
+
+from executorch.exir.graph_module import bfs_trace_with_node_process
+
 from tabulate import tabulate
 
+from torch.export import ExportedProgram
+
 FORWARD = "forward"
 EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"

@@ -888,3 +897,69 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
     else:
         # Raise an error if one is a sequence and the other is not
         raise ValueError("Both inputs must be sequences or both must be non-sequences.")
+
+
+def propagate_back_debug_handle(
+    exported_program: ExportedProgram,
+    exported_program_graph_id: int,
+    edge_dialect_program: ExportedProgram,
+) -> bool:
+    """
+    Propagate debug handles from the edge dialect program back to the exported program
+    while maintaining the correctness of operator tracing.
+
+    e.g.
+    export program: op1 -> op2 -> op3
+    edge dialect program: op1_0 -> op3_0 -> op3_1
+    where op1_0 is from op1, op3_0 and op3_1 are from op3, and op2 is removed by the
+    to_edge pipeline (e.g. RemoveNoopPass).
+
+    Then the debug handle of op1 should be the same as that of op1_0, and the debug
+    handle of op3 should be the same as that of op3_0 and op3_1. The debug handle of
+    op2 will be one that does not exist in the edge dialect program, so it can be
+    skipped later.
+
+    Return True if:
+        a. every debug handle in the edge dialect program has a corresponding node in
+           the exported program
+        b. the exported program is the greatest ancestor of the edge dialect program
+    Otherwise, return False.
+    """
+
+    # 1. Build a mapping from the identifier of each exported-program node to its
+    #    debug handle, using the edge dialect program nodes' debug handles and
+    #    from_node provenance info.
+    export_graph_node_id_to_debug_handle = {
+        get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
+        for node in edge_dialect_program.graph.nodes
+        if node.op not in ("placeholder", "output")
+    }
+
+    # 2. Equip the exported program's nodes with debug handles using the mapping.
+    # Number of exported-program nodes with a matching entry in
+    # export_graph_node_id_to_debug_handle.
+    n_matched_node = 0
+
+    # Debug handle for nodes that are in the exported program but not in the edge
+    # dialect program.
+    debug_handle_for_removed_node = (
+        max(export_graph_node_id_to_debug_handle.values()) + 1
+    )
+
+    def _find_n_match_node(node: torch.fx.Node) -> None:
+        nonlocal n_matched_node
+        if node.name in ("output", "placeholder"):
+            return
+        node_id = f"{node.name}.{exported_program_graph_id}"
+        if node_id in export_graph_node_id_to_debug_handle:
+            n_matched_node += 1
+
+    def _equip_debug_handle(node: torch.fx.Node) -> None:
+        if node.name in ("output", "placeholder"):
+            return
+        node_id = f"{node.name}.{exported_program_graph_id}"
+        if node_id in export_graph_node_id_to_debug_handle:
+            node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
+        else:
+            node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node
+
+    bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
+    if n_matched_node != len(export_graph_node_id_to_debug_handle):
+        return False
+
+    bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
+    return True
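
Putting the pieces together, the intended call pattern looks like this minimal sketch (modeled on the tests below; MyModel and its inputs are hypothetical placeholders, and it assumes to_edge() has run DebugHandleGeneratorPass so the edge dialect nodes already carry debug handles):

import torch
from torch.export import export

from executorch.devtools.inspector._inspector_utils import propagate_back_debug_handle
from executorch.exir import to_edge
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY


class MyModel(torch.nn.Module):  # hypothetical model, stands in for any eager module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x + 1)


exported_program = export(MyModel(), (torch.rand(5, 5),))
# Capture the graph id before calling to_edge(), because the from_node provenance
# recorded in the edge dialect program refers back to this original graph.
export_graph_id = id(exported_program.graph)

edge_dialect_program = to_edge(exported_program).exported_program()

if propagate_back_debug_handle(exported_program, export_graph_id, edge_dialect_program):
    # Every non-placeholder/output node now carries a debug handle matching its
    # descendant node(s) in the edge dialect program.
    for node in exported_program.graph.nodes:
        if node.op not in ("placeholder", "output"):
            print(node.name, node.meta[DEBUG_HANDLE_KEY])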

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 113 additions & 1 deletion

@@ -10,8 +10,9 @@
 import unittest
 from typing import Dict, Tuple
 
-import torch
+import executorch.exir.tests.models as models
 
+import torch
 from executorch.devtools import generate_etrecord, parse_etrecord
 
 from executorch.devtools.debug_format.base_schema import (

@@ -41,9 +42,13 @@
     map_runtime_aot_intermediate_outputs,
     merge_runtime_overlapping_debug_handles,
     NodeFilter,
+    propagate_back_debug_handle,
     TimeScale,
 )
 from executorch.devtools.inspector.numerical_comparator import L1Comparator
+from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
+from torch.export import export
 
 
 class TestInspectorUtils(unittest.TestCase):

@@ -583,6 +588,113 @@ def test_compare_intermediate_outputs_sequence_and_non_sequence(self):
         with self.assertRaises(ValueError):
             compare_intermediate_outputs(a, b, L1Comparator())
 
+    def test_equip_debug_handle_to_export_program_success(self):
+        """Test that propagate_back_debug_handle returns True and properly equips debug handles."""
+        # Create a test model
+        model = models.FeedForwardBlock(5, 10)
+        inputs = (torch.rand(5, 5),)
+
+        # Export the model
+        exported_program = export(model, inputs)
+        export_graph_id = id(exported_program.graph)
+
+        # Convert to edge dialect
+        edge_dialect_program = to_edge(exported_program).exported_program()
+
+        # Call propagate_back_debug_handle
+        result = propagate_back_debug_handle(
+            exported_program, export_graph_id, edge_dialect_program
+        )
+
+        self.assertTrue(result)
+
+        # Check that debug handles are properly equipped in the exported program
+        exported_program_debug_handles = []
+        for node in exported_program.graph.nodes:
+            if node.op not in ("placeholder", "output"):
+                self.assertIn(DEBUG_HANDLE_KEY, node.meta)
+                self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
+                exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
+
+        edge_dialect_program_debug_handles = []
+        for node in edge_dialect_program.graph.nodes:
+            if node.op not in ("placeholder", "output"):
+                self.assertIn(DEBUG_HANDLE_KEY, node.meta)
+                self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
+                edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
+
+        # The 0th operator in the exported program (layer_norm) is decomposed into the
+        # 0th and 1st ops in the edge dialect graph (native_layer_norm and getitem),
+        # so they should share the same debug handle.
+        self.assertEqual(
+            exported_program_debug_handles[0], edge_dialect_program_debug_handles[0]
+        )
+        self.assertEqual(
+            exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
+        )
+
+    def test_equip_debug_handle_to_export_program_failure(self):
+        """Test that propagate_back_debug_handle returns False when there's a mismatch."""
+        # Create a test model
+        model = models.FeedForwardBlock(5, 10)
+        inputs = (torch.rand(5, 5),)
+
+        exported_program = export(model, inputs)
+        edge_dialect_program = to_edge(exported_program).exported_program()
+
+        # Create a different exported program (re-export) to cause a mismatch
+        reexported_program = export(model, inputs)
+        reexport_graph_id = id(reexported_program.graph)
+
+        # Call propagate_back_debug_handle with mismatched programs.
+        # This should return False because the re-exported program has different
+        # node identifiers.
+        result = propagate_back_debug_handle(
+            reexported_program, reexport_graph_id, edge_dialect_program
+        )
+
+        # Check that it returns False due to the mismatch
+        self.assertFalse(result)
+
+    def test_equip_debug_handle_to_export_program_op_to_be_removed_in_to_edge(self):
+        """Test that propagate_back_debug_handle returns True and properly equips debug handles when an op is removed in to_edge"""
+
+        class M(torch.nn.Module):
+            """
+            Simple model with ops that will be removed in to_edge
+            """
+
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = x + 1
+                x = x.to(x.dtype)
+                x = x + 1
+                return x
+
+        inputs = (torch.rand(5, 5),)
+        exported_program = torch.export.export(M(), inputs)
+        export_graph_id = id(exported_program.graph)
+        edge_dialect_program = to_edge(exported_program).exported_program()
+
+        self.assertTrue(
+            propagate_back_debug_handle(
+                exported_program, export_graph_id, edge_dialect_program
+            )
+        )
+
+        # Only the two add ops in the exported program are kept in the edge dialect
+        # program, so the debug handle for the removed op will be 3.
+        debug_handle_for_removed_node = 3
+
+        for node in exported_program.graph.nodes:
+            if node.name == "add":
+                self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1)
+            elif node.name == "add_1":
+                self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2)
+            elif node.op not in ("placeholder", "output"):
+                self.assertEqual(
+                    node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node
+                )
 
 
 def gen_mock_operator_graph_with_expected_map() -> (
     Tuple[OperatorGraph, Dict[int, OperatorNode]]

exir/TARGETS

Lines changed: 8 additions & 0 deletions

@@ -277,3 +277,11 @@ python_library(
         "fbsource//third-party/pypi/typing-extensions:typing-extensions",
     ],
 )
+
+python_library(
+    name = "debug_handle_utils",
+    srcs = ["debug_handle_utils.py"],
+    deps = [
+        "//caffe2:torch",
+    ],
+)

exir/debug_handle_utils.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.fx import Node
+
+FROM_NODE_KEY = "from_node"
+DEBUG_HANDLE_KEY = "debug_handle"
+
+
+def get_greatest_ancestor_node_identifier(node: Node) -> str:
+    """Get the identifier of the greatest ancestor node of the given node.
+
+    The identifier is the concatenation of the node name and graph id of the
+    greatest ancestor node, where the graph id is the unique id for every graph
+    module in the export flow and node name is unique within the same graph module.
+    """
+
+    node_source = node.meta[FROM_NODE_KEY]
+    node_source = node_source[-1]
+
+    while len(node_source.from_node) > 0:
+        node_source = node_source.from_node[-1]
+
+    return f"{node_source.name}.{str(node_source.graph_id)}"

exir/passes/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -342,6 +342,7 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/exir:debug_handle_utils",
         "//executorch/exir:graph_module",
         "//executorch/exir:pass_base",
     ],

exir/passes/debug_handle_generator_pass.py

Lines changed: 6 additions & 20 deletions

@@ -6,6 +6,11 @@
 
 from typing import Dict
 
+from executorch.exir.debug_handle_utils import (
+    DEBUG_HANDLE_KEY,
+    FROM_NODE_KEY,
+    get_greatest_ancestor_node_identifier,
+)
 from executorch.exir.graph_module import bfs_trace_with_node_process
 from executorch.exir.pass_base import ExportPass
 from torch.export import ExportedProgram

@@ -21,27 +26,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
         greatest ancestor node in the export flow.
         """
 
-        FROM_NODE_KEY = "from_node"
-        DEBUG_HANDLE_KEY = "debug_handle"
-
         source_node_id_to_debug_handle: Dict[str, int] = {}
 
-        def _get_greatest_ancestor_node_identifier(node: Node) -> str:
-            """Get the identifier of the greatest ancestor node of the given node.
-
-            The identifier is the concatenation of the node name and graph id of the
-            greatest ancestor node, where the graph id is the unique id for every graph
-            module in the export flow and node name is unique within the same graph module.
-            """
-
-            node_source = node.meta[FROM_NODE_KEY]
-            node_source = node_source[-1]
-
-            while len(node_source.from_node) > 0:
-                node_source = node_source.from_node[-1]
-
-            return node_source.name + str(node_source.graph_id)
-
         def _extract_debug_handles_from_node(node: Node) -> None:
             """
             Generate a debug handle based on node's oldest ancestor node's name

@@ -56,7 +42,7 @@ def _extract_debug_handles_from_node(node: Node) -> None:
                 FROM_NODE_KEY in node.meta
             ), f"Node {node} does not have meta key {FROM_NODE_KEY}"
 
-            greatest_ancestor_node_id = _get_greatest_ancestor_node_identifier(node)
+            greatest_ancestor_node_id = get_greatest_ancestor_node_identifier(node)
 
             debug_handle = (
                 len(source_node_id_to_debug_handle) + 1