1111from collections .abc import Sequence
1212from dataclasses import dataclass
1313from enum import Enum
14- from typing import Any , Dict , IO , List , Mapping , Optional , Tuple , TypeAlias , Union
14+ from typing import Any , Dict , IO , List , Mapping , Optional , Set , Tuple , TypeAlias , Union
1515
1616import executorch .devtools .etdump .schema_flatcc as flatcc
1717
3737
3838from executorch .exir .debug_handle_utils import (
3939 DEBUG_HANDLE_KEY ,
40+ FROM_NODE_KEY ,
4041 get_greatest_ancestor_node_identifier ,
4142 UNSET_DEBUG_HANDLE ,
4243)
4647from tabulate import tabulate
4748
4849from torch .export import ExportedProgram
50+ from torch .fx import Node
4951
5052FORWARD = "forward"
5153EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +938,44 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936938 )
937939
938940
941+ def get_ancestor_node_identifiers (node : Node ) -> List [str ]:
942+ """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.
943+
944+ The identifier is the concatenation of the node name and graph id of the
945+ greatest ancestor node, where the graph id is the unique id for every graph
946+ module in the export flow and node name is unique within the same graph module.
947+
948+ Returns: the identifiers of all its ancestor nodes
949+ """
950+
951+ node_source = node .meta [FROM_NODE_KEY ]
952+ node_source = node_source [- 1 ]
953+ ancestor_node_ids : List [str ] = [f"{ node_source .name } .{ str (node_source .graph_id )} " ]
954+
955+ while len (node_source .from_node ) > 0 :
956+ node_source = node_source .from_node [- 1 ]
957+ ancestor_node_ids .append (f"{ node_source .name } .{ str (node_source .graph_id )} " )
958+
959+ return ancestor_node_ids
960+
961+
962+ def get_parent_node_identifier (node : Node ) -> Optional [str ]:
963+ """Get the identifier of the parent node of the given node, with the graph id the parent node lives in.
964+
965+ The identifier is the concatenation of the node name and graph id of the
966+ greatest parent node, where the graph id is the unique id for every graph
967+ module in the export flow and node name is unique within the same graph module.
968+
969+ Returns: the identifier of the parent node, or None if can not find the parent
970+ """
971+
972+ if FROM_NODE_KEY not in node .meta :
973+ return None
974+
975+ node_source = node .meta [FROM_NODE_KEY ][- 1 ]
976+ return f"{ node_source .name } .{ str (node_source .graph_id )} "
977+
978+
939979def propagate_back_debug_handle (
940980 exported_program : ExportedProgram ,
941981 exported_program_graph_id : int ,
@@ -953,47 +993,78 @@ def propagate_back_debug_handle(
953993 Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
954994 The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
955995
956- Return: True if:
957- a. every debug handle in the edge dialect program has a corresponding node in the exported program
958- b. the exported program is the greatest ancestor of the edge dialect program
959-
960- Otherwise, return False.
996+ Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
961997 """
962998
963- # 1. set up a mapping from debug handle to identifier of export program's node
999+ # 1. set up a mapping from identifier of every possible ancestor node id to debug handle
9641000 # using edge dialect program nodes' debug handles and from_node info
965- export_graph_node_id_to_debug_handle = {
966- get_greatest_ancestor_node_identifier (node ): node .meta [DEBUG_HANDLE_KEY ]
967- for node in edge_dialect_program .graph .nodes
968- if node .op not in ("placeholder" , "output" )
969- }
970-
971- # 2. equip debug handle to the exported program's nodes using the mapping
972- # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973- n_matched_node = 0
1001+ ancestors_node_id_to_debug_handle : Dict [str , int ] = {}
9741002
975- def _find_n_match_node (node : torch . fx . Node ) -> None :
976- nonlocal n_matched_node
977- if node .name in ("output " , "placeholder " ):
1003+ def _extract_node_id_to_debug_handle (node : Node ) -> None :
1004+ nonlocal ancestors_node_id_to_debug_handle
1005+ if node .op in ("placeholder " , "output " ):
9781006 return
979- node_id = f"{ node .name } .{ exported_program_graph_id } "
980- if node_id in export_graph_node_id_to_debug_handle :
981- n_matched_node += 1
1007+ for ancestor_node_id in get_ancestor_node_identifiers (node ):
1008+ if ancestor_node_id not in ancestors_node_id_to_debug_handle :
1009+ ancestors_node_id_to_debug_handle [ancestor_node_id ] = node .meta [
1010+ DEBUG_HANDLE_KEY
1011+ ]
1012+ else :
1013+ assert (
1014+ ancestors_node_id_to_debug_handle [ancestor_node_id ]
1015+ == node .meta [DEBUG_HANDLE_KEY ]
1016+ )
1017+
1018+ bfs_trace_with_node_process (
1019+ edge_dialect_program .graph_module , _extract_node_id_to_debug_handle
1020+ )
9821021
983- def _equip_debug_handle (node : torch .fx .Node ) -> None :
984- if node .name in ("output" , "placeholder" ):
1022+ # 2. verify if every debug handle in the edge dialect program has a corresponding node in the exported program
1023+ matched_debug_handes : Set [int ] = set ()
1024+
1025+ def _find_n_match_node (node : Node ) -> None :
1026+ nonlocal matched_debug_handes
1027+ if node .op in ("output" , "placeholder" ):
9851028 return
9861029 node_id = f"{ node .name } .{ exported_program_graph_id } "
987- if node_id in export_graph_node_id_to_debug_handle :
988- node .meta [DEBUG_HANDLE_KEY ] = export_graph_node_id_to_debug_handle [node_id ]
989- else :
990- node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1030+ parent_node_id = get_parent_node_identifier (node )
1031+ if node_id in ancestors_node_id_to_debug_handle :
1032+ matched_debug_handes .add (ancestors_node_id_to_debug_handle [node_id ])
1033+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1034+ matched_debug_handes .add (ancestors_node_id_to_debug_handle [parent_node_id ])
9911035
9921036 bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
9931037
1038+ graph_matched = True
1039+
1040+ def _check_graph_match (node : Node ) -> None :
1041+ nonlocal graph_matched
1042+ if node .op in ("output" , "placeholder" ):
1043+ return
1044+
1045+ if node .meta [DEBUG_HANDLE_KEY ] not in matched_debug_handes :
1046+ graph_matched = False
1047+
1048+ bfs_trace_with_node_process (edge_dialect_program .graph_module , _check_graph_match )
1049+
9941050 # if any node in the edge dialect program has no corresponding node in the exported program, match failed
995- if n_matched_node != len ( export_graph_node_id_to_debug_handle ) :
1051+ if not graph_matched :
9961052 return False
9971053
1054+ # 3. propagate debug handle from edge dialect program back to the exported program while maintain the correctness of operator tracing
1055+ def _equip_debug_handle (node : Node ) -> None :
1056+ if node .op in ("output" , "placeholder" ):
1057+ return
1058+ node_id = f"{ node .name } .{ exported_program_graph_id } "
1059+ parent_node_id = get_parent_node_identifier (node )
1060+ if node_id in ancestors_node_id_to_debug_handle :
1061+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [node_id ]
1062+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1063+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1064+ parent_node_id
1065+ ]
1066+ else :
1067+ node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1068+
9981069 bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
9991070 return True
0 commit comments