Skip to content

Commit d5a03a2

Browse files
committed
Support back-propagating debug handles to an arbitrary ancestor export graph
Pull Request resolved: #12580 Currently, the propagate_back_debug_handle function can only propagate debug handles back to the greatest ancestor export graph. This diff updates the algorithm to support every possible ancestor export graph in the flow. Differential Revision: [D78464992](https://our.internmc.facebook.com/intern/diff/D78464992/) ghstack-source-id: 296760264
1 parent 1af653c commit d5a03a2

File tree

2 files changed

+234
-40
lines changed

2 files changed

+234
-40
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 145 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Sequence
1212
from dataclasses import dataclass
1313
from enum import Enum
14-
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
14+
from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union
1515

1616
import executorch.devtools.etdump.schema_flatcc as flatcc
1717

@@ -37,7 +37,7 @@
3737

3838
from executorch.exir.debug_handle_utils import (
3939
DEBUG_HANDLE_KEY,
40-
get_greatest_ancestor_node_identifier,
40+
FROM_NODE_KEY,
4141
UNSET_DEBUG_HANDLE,
4242
)
4343

@@ -46,6 +46,7 @@
4646
from tabulate import tabulate
4747

4848
from torch.export import ExportedProgram
49+
from torch.fx import Node
4950

5051
FORWARD = "forward"
5152
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936937
)
937938

938939

940+
def get_ancestor_node_identifiers(node: Node) -> List[str]:
    """Get the identifiers of all ancestor nodes of the given node.

    Each identifier is the concatenation ``"<node name>.<graph id>"`` of one
    entry in the node's provenance chain, where the graph id is the unique id
    for every graph module in the export flow and the node name is unique
    within the same graph module.

    The chain starts at the last entry of ``node.meta[FROM_NODE_KEY]`` and is
    followed through the last entry of each successive ``from_node`` list
    until an entry with no ``from_node`` is reached.

    Args:
        node: The node whose provenance chain is walked.

    Returns:
        The identifiers in the order the chain is walked, starting with the
        entry taken directly from the node's metadata. The list is empty when
        ``FROM_NODE_KEY`` is missing from ``node.meta`` or maps to an empty
        value.
    """
    # A node without provenance has no ancestors. Report that as an empty
    # list (get_parent_node_identifier reports the same situation as None)
    # rather than failing with KeyError/IndexError.
    provenance = node.meta.get(FROM_NODE_KEY)
    if not provenance:
        return []

    ancestor_node_ids: List[str] = []
    node_source = provenance[-1]
    while True:
        ancestor_node_ids.append(f"{node_source.name}.{node_source.graph_id}")
        if not node_source.from_node:
            break
        # Only the last recorded source of each entry is followed.
        node_source = node_source.from_node[-1]

    return ancestor_node_ids
959+
960+
961+
def get_parent_node_identifier(node: Node) -> Optional[str]:
    """Get the identifier of the parent node of the given node.

    The identifier is the concatenation ``"<node name>.<graph id>"`` built
    from the last entry of ``node.meta[FROM_NODE_KEY]``, where the graph id is
    the unique id for every graph module in the export flow and the node name
    is unique within the same graph module.

    Args:
        node: The node whose parent identifier is requested.

    Returns:
        The identifier of the parent node, or None if the parent cannot be
        found, i.e. ``FROM_NODE_KEY`` is missing from ``node.meta`` or maps to
        an empty value.
    """
    # An empty provenance list means "no parent" just like a missing key;
    # indexing it with [-1] would raise IndexError instead of returning None.
    provenance = node.meta.get(FROM_NODE_KEY)
    if not provenance:
        return None

    node_source = provenance[-1]
    return f"{node_source.name}.{node_source.graph_id}"
976+
977+
978+
def _extract_ancestor_debug_handles(
    edge_dialect_program: ExportedProgram,
) -> Dict[str, int]:
    """Build a lookup from ancestor node identifier to debug handle.

    For each node that bfs_trace_with_node_process hands to the callback
    (placeholders and outputs excluded), every identifier returned by
    get_ancestor_node_identifiers is associated with that node's debug handle.
    An identifier reached again from another node is asserted to carry the
    same debug handle as the one recorded first.

    Args:
        edge_dialect_program: Program whose graph module is traversed.

    Returns:
        Mapping from ancestor node identifier to debug handle.
    """
    handle_by_ancestor_id: Dict[str, int] = {}

    def _record_ancestors(node: Node) -> None:
        # Graph inputs and outputs do not contribute to the mapping.
        if node.op == "placeholder" or node.op == "output":
            return
        for identifier in get_ancestor_node_identifiers(node):
            if identifier in handle_by_ancestor_id:
                # A shared ancestor must agree with the handle seen earlier.
                assert (
                    handle_by_ancestor_id[identifier] == node.meta[DEBUG_HANDLE_KEY]
                )
                continue
            handle_by_ancestor_id[identifier] = node.meta[DEBUG_HANDLE_KEY]

    bfs_trace_with_node_process(
        edge_dialect_program.graph_module, _record_ancestors
    )
    return handle_by_ancestor_id
1002+
1003+
1004+
def _find_matched_debug_handles(
    exported_program: ExportedProgram,
    exported_program_graph_id: int,
    ancestors_node_id_to_debug_handle: Dict[str, int],
) -> Set[int]:
    """Collect the debug handles that nodes of the exported program resolve to.

    A node resolves through its own identifier
    (``"<node name>.<exported_program_graph_id>"``) when that identifier is in
    the mapping; otherwise through its parent identifier when one exists and
    is in the mapping. Placeholder and output nodes are ignored.

    Args:
        exported_program: Program whose graph module is traversed.
        exported_program_graph_id: Graph id used to build node identifiers.
        ancestors_node_id_to_debug_handle: Mapping from ancestor node
            identifier to debug handle.

    Returns:
        The set of debug handles reached by at least one node.
    """
    found_handles: Set[int] = set()

    def _collect_match(node: Node) -> None:
        if node.op == "output" or node.op == "placeholder":
            return
        own_id = f"{node.name}.{exported_program_graph_id}"
        parent_id = get_parent_node_identifier(node)
        # The node's own identifier takes priority; the parent is a fallback.
        if own_id in ancestors_node_id_to_debug_handle:
            key = own_id
        elif parent_id and parent_id in ancestors_node_id_to_debug_handle:
            key = parent_id
        else:
            return
        found_handles.add(ancestors_node_id_to_debug_handle[key])

    bfs_trace_with_node_process(exported_program.graph_module, _collect_match)
    return found_handles
1024+
1025+
1026+
def _verify_graph_match(
    edge_dialect_program: ExportedProgram, matched_debug_handles: Set[int]
) -> bool:
    """Check that every edge dialect debug handle was matched.

    Args:
        edge_dialect_program: Program whose graph module is traversed.
        matched_debug_handles: Debug handles that have a corresponding node
            in the exported program.

    Returns:
        True if the debug handle of every processed node (placeholders and
        outputs excluded) is in ``matched_debug_handles``, otherwise False.
    """
    # One entry is appended per unmatched node; an empty list means success.
    unmatched_seen: List[bool] = []

    def _note_unmatched(node: Node) -> None:
        if node.op == "output" or node.op == "placeholder":
            return
        if node.meta[DEBUG_HANDLE_KEY] in matched_debug_handles:
            return
        unmatched_seen.append(True)

    bfs_trace_with_node_process(edge_dialect_program.graph_module, _note_unmatched)
    return not unmatched_seen
1041+
1042+
1043+
def _apply_debug_handles(
    exported_program: ExportedProgram,
    exported_program_graph_id: int,
    ancestors_node_id_to_debug_handle: Dict[str, int],
) -> None:
    """Write debug handles into the metadata of the exported program's nodes.

    Each processed node (placeholders and outputs excluded) receives the
    handle mapped to its own identifier
    (``"<node name>.<exported_program_graph_id>"``); failing that, the handle
    mapped to its parent identifier; failing that, UNSET_DEBUG_HANDLE.

    Args:
        exported_program: Program whose nodes are updated in place.
        exported_program_graph_id: Graph id used to build node identifiers.
        ancestors_node_id_to_debug_handle: Mapping from ancestor node
            identifier to debug handle.
    """

    def _assign_handle(node: Node) -> None:
        if node.op == "output" or node.op == "placeholder":
            return
        own_id = f"{node.name}.{exported_program_graph_id}"
        parent_id = get_parent_node_identifier(node)
        if own_id in ancestors_node_id_to_debug_handle:
            handle = ancestors_node_id_to_debug_handle[own_id]
        elif parent_id and parent_id in ancestors_node_id_to_debug_handle:
            handle = ancestors_node_id_to_debug_handle[parent_id]
        else:
            # No match through either identifier: mark the node as unset.
            handle = UNSET_DEBUG_HANDLE
        node.meta[DEBUG_HANDLE_KEY] = handle

    bfs_trace_with_node_process(exported_program.graph_module, _assign_handle)
1065+
1066+
9391067
def propagate_back_debug_handle(
9401068
exported_program: ExportedProgram,
9411069
exported_program_graph_id: int,
@@ -953,47 +1081,24 @@ def propagate_back_debug_handle(
9531081
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
9541082
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
9551083
956-
Return: True if:
957-
a. every debug handle in the edge dialect program has a corresponding node in the exported program
958-
b. the exported program is the greatest ancestor of the edge dialect program
959-
960-
Otherwise, return False.
1084+
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
9611085
"""
1086+
# 1. Extract mapping from ancestor node identifiers to debug handles
1087+
ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles(
1088+
edge_dialect_program
1089+
)
9621090

963-
# 1. set up a mapping from debug handle to identifier of export program's node
964-
# using edge dialect program nodes' debug handles and from_node info
965-
export_graph_node_id_to_debug_handle = {
966-
get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
967-
for node in edge_dialect_program.graph.nodes
968-
if node.op not in ("placeholder", "output")
969-
}
970-
971-
# 2. equip debug handle to the exported program's nodes using the mapping
972-
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973-
n_matched_node = 0
974-
975-
def _find_n_match_node(node: torch.fx.Node) -> None:
976-
nonlocal n_matched_node
977-
if node.name in ("output", "placeholder"):
978-
return
979-
node_id = f"{node.name}.{exported_program_graph_id}"
980-
if node_id in export_graph_node_id_to_debug_handle:
981-
n_matched_node += 1
982-
983-
def _equip_debug_handle(node: torch.fx.Node) -> None:
984-
if node.name in ("output", "placeholder"):
985-
return
986-
node_id = f"{node.name}.{exported_program_graph_id}"
987-
if node_id in export_graph_node_id_to_debug_handle:
988-
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
989-
else:
990-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
991-
992-
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
1091+
# 2. Find debug handles that have corresponding nodes in the exported program
1092+
matched_debug_handles = _find_matched_debug_handles(
1093+
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
1094+
)
9931095

994-
# if any node in the edge dialect program has no corresponding node in the exported program, match failed
995-
if n_matched_node != len(export_graph_node_id_to_debug_handle):
1096+
# 3. Verify if every debug handle in edge dialect program has a corresponding node
1097+
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
9961098
return False
9971099

998-
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
1100+
# 4. Apply debug handles to the exported program
1101+
_apply_debug_handles(
1102+
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
1103+
)
9991104
return True

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,95 @@ def test_equip_debug_handle_to_export_program_success(self):
654654
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
655655
)
656656

657+
def test_equip_debug_handle_to_strict_export_program_success(self):
    """propagate_back_debug_handle returns True for a strict export and equips its nodes."""

    def _collect_debug_handles(program):
        # Every non-placeholder/non-output node must carry a non-None handle.
        handles = []
        for graph_node in program.graph.nodes:
            if graph_node.op in ("placeholder", "output"):
                continue
            self.assertIn(DEBUG_HANDLE_KEY, graph_node.meta)
            self.assertIsNotNone(graph_node.meta[DEBUG_HANDLE_KEY])
            handles.append(graph_node.meta[DEBUG_HANDLE_KEY])
        return handles

    # Build the model and export it in strict mode.
    model = models.FeedForwardBlock(5, 10)
    inputs = (torch.rand(5, 5),)
    exported_program = export(model, inputs, strict=True)
    export_graph_id = id(exported_program.graph)

    # Lower to edge dialect, then propagate the debug handles back.
    edge_dialect_program = to_edge(exported_program).exported_program()
    self.assertTrue(
        propagate_back_debug_handle(
            exported_program, export_graph_id, edge_dialect_program
        )
    )

    exported_program_debug_handles = _collect_debug_handles(exported_program)
    edge_dialect_program_debug_handles = _collect_debug_handles(
        edge_dialect_program
    )

    # The 0th operator in the exported program (layer_norm) has been decomposed
    # into the 0th and 1st ops in the edge dialect graph (native_layer_norm and
    # getitem), so all three should share one debug handle.
    for edge_index in (0, 1):
        self.assertEqual(
            exported_program_debug_handles[0],
            edge_dialect_program_debug_handles[edge_index],
        )
700+
701+
def test_equip_debug_handle_to_reexport_program_success(self):
    """propagate_back_debug_handle returns True for a re-exported program and equips its nodes."""

    def _collect_debug_handles(program):
        # Every non-placeholder/non-output node must carry a non-None handle.
        handles = []
        for graph_node in program.graph.nodes:
            if graph_node.op in ("placeholder", "output"):
                continue
            self.assertIn(DEBUG_HANDLE_KEY, graph_node.meta)
            self.assertIsNotNone(graph_node.meta[DEBUG_HANDLE_KEY])
            handles.append(graph_node.meta[DEBUG_HANDLE_KEY])
        return handles

    # Build the model, export it once, then export the resulting module again.
    model = models.FeedForwardBlock(5, 10)
    inputs = (torch.rand(5, 5),)
    init_export_program = export(model, inputs)
    exported_program = export(init_export_program.module(), inputs)
    export_graph_id = id(exported_program.graph)

    # Lower to edge dialect, then propagate the debug handles back.
    edge_dialect_program = to_edge(exported_program).exported_program()
    self.assertTrue(
        propagate_back_debug_handle(
            exported_program, export_graph_id, edge_dialect_program
        )
    )

    exported_program_debug_handles = _collect_debug_handles(exported_program)
    edge_dialect_program_debug_handles = _collect_debug_handles(
        edge_dialect_program
    )

    # The 0th operator in the exported program (layer_norm) has been decomposed
    # into the 0th and 1st ops in the edge dialect graph (native_layer_norm and
    # getitem), so all three should share one debug handle.
    for edge_index in (0, 1):
        self.assertEqual(
            exported_program_debug_handles[0],
            edge_dialect_program_debug_handles[edge_index],
        )
745+
657746
def test_equip_debug_handle_to_export_program_failure(self):
658747
"""Test that propagate_back_debug_handle returns False when there's a mismatch."""
659748
# Create a test model

0 commit comments

Comments
 (0)