1111from collections .abc import Sequence
1212from dataclasses import dataclass
1313from enum import Enum
14- from typing import Any , Dict , IO , List , Mapping , Optional , Tuple , TypeAlias , Union
14+ from typing import Any , Dict , IO , List , Mapping , Optional , Set , Tuple , TypeAlias , Union
1515
1616import executorch .devtools .etdump .schema_flatcc as flatcc
1717
3737
3838from executorch .exir .debug_handle_utils import (
3939 DEBUG_HANDLE_KEY ,
40- get_greatest_ancestor_node_identifier ,
40+ FROM_NODE_KEY ,
4141 UNSET_DEBUG_HANDLE ,
4242)
4343
4646from tabulate import tabulate
4747
4848from torch .export import ExportedProgram
49+ from torch .fx import Node
4950
5051FORWARD = "forward"
5152EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936937 )
937938
938939
940+ def get_ancestor_node_identifiers (node : Node ) -> List [str ]:
941+ """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.
942+
943+ The identifier is the concatenation of the node name and graph id of the
944+ greatest ancestor node, where the graph id is the unique id for every graph
945+ module in the export flow and node name is unique within the same graph module.
946+
947+ Returns: the identifiers of all its ancestor nodes
948+ """
949+
950+ node_source = node .meta [FROM_NODE_KEY ]
951+ node_source = node_source [- 1 ]
952+ ancestor_node_ids : List [str ] = [f"{ node_source .name } .{ str (node_source .graph_id )} " ]
953+
954+ while len (node_source .from_node ) > 0 :
955+ node_source = node_source .from_node [- 1 ]
956+ ancestor_node_ids .append (f"{ node_source .name } .{ str (node_source .graph_id )} " )
957+
958+ return ancestor_node_ids
959+
960+
961+ def get_parent_node_identifier (node : Node ) -> Optional [str ]:
962+ """Get the identifier of the parent node of the given node, with the graph id the parent node lives in.
963+
964+ The identifier is the concatenation of the node name and graph id of the
965+ greatest parent node, where the graph id is the unique id for every graph
966+ module in the export flow and node name is unique within the same graph module.
967+
968+ Returns: the identifier of the parent node, or None if can not find the parent
969+ """
970+
971+ if FROM_NODE_KEY not in node .meta :
972+ return None
973+
974+ node_source = node .meta [FROM_NODE_KEY ][- 1 ]
975+ return f"{ node_source .name } .{ str (node_source .graph_id )} "
976+
977+
978+ def _extract_ancestor_debug_handles (
979+ edge_dialect_program : ExportedProgram ,
980+ ) -> Dict [str , int ]:
981+ """Extract mapping from ancestor node identifiers to debug handles."""
982+ ancestors_node_id_to_debug_handle : Dict [str , int ] = {}
983+
984+ def _extract_node_id_to_debug_handle (node : Node ) -> None :
985+ if node .op in ("placeholder" , "output" ):
986+ return
987+ for ancestor_node_id in get_ancestor_node_identifiers (node ):
988+ if ancestor_node_id not in ancestors_node_id_to_debug_handle :
989+ ancestors_node_id_to_debug_handle [ancestor_node_id ] = node .meta [
990+ DEBUG_HANDLE_KEY
991+ ]
992+ else :
993+ assert (
994+ ancestors_node_id_to_debug_handle [ancestor_node_id ]
995+ == node .meta [DEBUG_HANDLE_KEY ]
996+ )
997+
998+ bfs_trace_with_node_process (
999+ edge_dialect_program .graph_module , _extract_node_id_to_debug_handle
1000+ )
1001+ return ancestors_node_id_to_debug_handle
1002+
1003+
1004+ def _find_matched_debug_handles (
1005+ exported_program : ExportedProgram ,
1006+ exported_program_graph_id : int ,
1007+ ancestors_node_id_to_debug_handle : Dict [str , int ],
1008+ ) -> Set [int ]:
1009+ """Find debug handles that have corresponding nodes in the exported program."""
1010+ matched_debug_handles : Set [int ] = set ()
1011+
1012+ def _find_n_match_node (node : Node ) -> None :
1013+ if node .op in ("output" , "placeholder" ):
1014+ return
1015+ node_id = f"{ node .name } .{ exported_program_graph_id } "
1016+ parent_node_id = get_parent_node_identifier (node )
1017+ if node_id in ancestors_node_id_to_debug_handle :
1018+ matched_debug_handles .add (ancestors_node_id_to_debug_handle [node_id ])
1019+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1020+ matched_debug_handles .add (ancestors_node_id_to_debug_handle [parent_node_id ])
1021+
1022+ bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
1023+ return matched_debug_handles
1024+
1025+
1026+ def _verify_graph_match (
1027+ edge_dialect_program : ExportedProgram , matched_debug_handles : Set [int ]
1028+ ) -> bool :
1029+ """Verify if every debug handle in edge dialect program has a corresponding node."""
1030+ graph_matched = True
1031+
1032+ def _check_graph_match (node : Node ) -> None :
1033+ nonlocal graph_matched
1034+ if node .op in ("output" , "placeholder" ):
1035+ return
1036+ if node .meta [DEBUG_HANDLE_KEY ] not in matched_debug_handles :
1037+ graph_matched = False
1038+
1039+ bfs_trace_with_node_process (edge_dialect_program .graph_module , _check_graph_match )
1040+ return graph_matched
1041+
1042+
1043+ def _apply_debug_handles (
1044+ exported_program : ExportedProgram ,
1045+ exported_program_graph_id : int ,
1046+ ancestors_node_id_to_debug_handle : Dict [str , int ],
1047+ ) -> None :
1048+ """Apply debug handles to the exported program nodes."""
1049+
1050+ def _equip_debug_handle (node : Node ) -> None :
1051+ if node .op in ("output" , "placeholder" ):
1052+ return
1053+ node_id = f"{ node .name } .{ exported_program_graph_id } "
1054+ parent_node_id = get_parent_node_identifier (node )
1055+ if node_id in ancestors_node_id_to_debug_handle :
1056+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [node_id ]
1057+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1058+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1059+ parent_node_id
1060+ ]
1061+ else :
1062+ node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1063+
1064+ bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
1065+
1066+
9391067def propagate_back_debug_handle (
9401068 exported_program : ExportedProgram ,
9411069 exported_program_graph_id : int ,
@@ -953,47 +1081,24 @@ def propagate_back_debug_handle(
9531081 Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
9541082 The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
9551083
956- Return: True if:
957- a. every debug handle in the edge dialect program has a corresponding node in the exported program
958- b. the exported program is the greatest ancestor of the edge dialect program
959-
960- Otherwise, return False.
1084+ Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
9611085 """
1086+ # 1. Extract mapping from ancestor node identifiers to debug handles
1087+ ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles (
1088+ edge_dialect_program
1089+ )
9621090
963- # 1. set up a mapping from debug handle to identifier of export program's node
964- # using edge dialect program nodes' debug handles and from_node info
965- export_graph_node_id_to_debug_handle = {
966- get_greatest_ancestor_node_identifier (node ): node .meta [DEBUG_HANDLE_KEY ]
967- for node in edge_dialect_program .graph .nodes
968- if node .op not in ("placeholder" , "output" )
969- }
970-
971- # 2. equip debug handle to the exported program's nodes using the mapping
972- # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973- n_matched_node = 0
974-
975- def _find_n_match_node (node : torch .fx .Node ) -> None :
976- nonlocal n_matched_node
977- if node .name in ("output" , "placeholder" ):
978- return
979- node_id = f"{ node .name } .{ exported_program_graph_id } "
980- if node_id in export_graph_node_id_to_debug_handle :
981- n_matched_node += 1
982-
983- def _equip_debug_handle (node : torch .fx .Node ) -> None :
984- if node .name in ("output" , "placeholder" ):
985- return
986- node_id = f"{ node .name } .{ exported_program_graph_id } "
987- if node_id in export_graph_node_id_to_debug_handle :
988- node .meta [DEBUG_HANDLE_KEY ] = export_graph_node_id_to_debug_handle [node_id ]
989- else :
990- node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
991-
992- bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
1091+ # 2. Find debug handles that have corresponding nodes in the exported program
1092+ matched_debug_handles = _find_matched_debug_handles (
1093+ exported_program , exported_program_graph_id , ancestors_node_id_to_debug_handle
1094+ )
9931095
994- # if any node in the edge dialect program has no corresponding node in the exported program, match failed
995- if n_matched_node != len ( export_graph_node_id_to_debug_handle ):
1096+ # 3. Verify if every debug handle in edge dialect program has a corresponding node
1097+ if not _verify_graph_match ( edge_dialect_program , matched_debug_handles ):
9961098 return False
9971099
998- bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
1100+ # 4. Apply debug handles to the exported program
1101+ _apply_debug_handles (
1102+ exported_program , exported_program_graph_id , ancestors_node_id_to_debug_handle
1103+ )
9991104 return True
0 commit comments