@@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs(
657657 # Combine all AOT debug_handles into a list
658658 aot_combined_debug_handle = [t [0 ] for t in aot_map .keys ()]
659659
660- if set (aot_combined_debug_handle ) != set (runtime_debug_handle ):
661- # AOT combined debug_handle and runtime debug_handle do not match.
660+ # Reason we dont check for exact match:
661+ # in some experiments where we want to rewrite the aten graph that was
662+ # lowered, so as to use custom ops like int4_matmul, we lose some nodes
663+ # on the graph and thus lose some debug handles. And we dont find
664+ # exact match within connected components.
665+ if not set (aot_combined_debug_handle ).issubset (set (runtime_debug_handle )):
666+ # AOT combined debug_handle is not a subset of runtime debug_handle.
662667 return (- 1 ,), None
663668
664669 # Pick the last intermediate output
665670 last_int = runtime_debug_handle [negative_index ]
666671 key = (last_int ,)
672+ if key not in aot_map :
673+ # If the last intermediate output is not in the AOT map, return None
674+ return (- 1 ,), None
667675 return runtime_debug_handle , aot_map [key ]
668676
669677
@@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None:
10591067 if node .op in ("output" , "placeholder" ):
10601068 return
10611069 node_id = f"{ node .name } .{ exported_program_graph_id } "
1062- parent_node_id = get_parent_node_identifier (node )
1070+ parent_node_ids = get_ancestor_node_identifiers (node )
10631071 if node_id in ancestors_node_id_to_debug_handle :
10641072 matched_debug_handles .add (ancestors_node_id_to_debug_handle [node_id ])
1065- elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1066- matched_debug_handles .add (ancestors_node_id_to_debug_handle [parent_node_id ])
1073+ elif parent_node_ids :
1074+ for parent_node_id in parent_node_ids :
1075+ if parent_node_id in ancestors_node_id_to_debug_handle :
1076+ matched_debug_handles .add (
1077+ ancestors_node_id_to_debug_handle [parent_node_id ]
1078+ )
1079+ break
10671080
10681081 bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
10691082 return matched_debug_handles
@@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None:
10971110 if node .op in ("output" , "placeholder" ):
10981111 return
10991112 node_id = f"{ node .name } .{ exported_program_graph_id } "
1100- parent_node_id = get_parent_node_identifier (node )
1113+ parent_node_ids = get_ancestor_node_identifiers (node )
1114+ node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
11011115 if node_id in ancestors_node_id_to_debug_handle :
11021116 node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [node_id ]
1103- elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1104- node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1105- parent_node_id
1106- ]
1107- else :
1108- node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1117+ elif parent_node_ids :
1118+ for parent_node_id in parent_node_ids :
1119+ if parent_node_id in ancestors_node_id_to_debug_handle :
1120+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1121+ parent_node_id
1122+ ]
1123+ break
11091124
11101125 bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
11111126
0 commit comments