@@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs(
657
657
# Combine all AOT debug_handles into a list
658
658
aot_combined_debug_handle = [t [0 ] for t in aot_map .keys ()]
659
659
660
- if set (aot_combined_debug_handle ) != set (runtime_debug_handle ):
661
- # AOT combined debug_handle and runtime debug_handle do not match.
660
+ # Reason we dont check for exact match:
661
+ # in some experiments where we want to rewrite the aten graph that was
662
+ # lowered, so as to use custom ops like int4_matmul, we lose some nodes
663
+ # on the graph and thus lose some debug handles. And we dont find
664
+ # exact match within connected components.
665
+ if not set (aot_combined_debug_handle ).issubset (set (runtime_debug_handle )):
666
+ # AOT combined debug_handle is not a subset of runtime debug_handle.
662
667
return (- 1 ,), None
663
668
664
669
# Pick the last intermediate output
665
670
last_int = runtime_debug_handle [negative_index ]
666
671
key = (last_int ,)
672
+ if key not in aot_map :
673
+ # If the last intermediate output is not in the AOT map, return None
674
+ return (- 1 ,), None
667
675
return runtime_debug_handle , aot_map [key ]
668
676
669
677
@@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None:
1059
1067
if node .op in ("output" , "placeholder" ):
1060
1068
return
1061
1069
node_id = f"{ node .name } .{ exported_program_graph_id } "
1062
- parent_node_id = get_parent_node_identifier (node )
1070
+ parent_node_ids = get_ancestor_node_identifiers (node )
1063
1071
if node_id in ancestors_node_id_to_debug_handle :
1064
1072
matched_debug_handles .add (ancestors_node_id_to_debug_handle [node_id ])
1065
- elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1066
- matched_debug_handles .add (ancestors_node_id_to_debug_handle [parent_node_id ])
1073
+ elif parent_node_ids :
1074
+ for parent_node_id in parent_node_ids :
1075
+ if parent_node_id in ancestors_node_id_to_debug_handle :
1076
+ matched_debug_handles .add (
1077
+ ancestors_node_id_to_debug_handle [parent_node_id ]
1078
+ )
1079
+ break
1067
1080
1068
1081
bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
1069
1082
return matched_debug_handles
@@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None:
1097
1110
if node .op in ("output" , "placeholder" ):
1098
1111
return
1099
1112
node_id = f"{ node .name } .{ exported_program_graph_id } "
1100
- parent_node_id = get_parent_node_identifier (node )
1113
+ parent_node_ids = get_ancestor_node_identifiers (node )
1114
+ node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1101
1115
if node_id in ancestors_node_id_to_debug_handle :
1102
1116
node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [node_id ]
1103
- elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1104
- node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1105
- parent_node_id
1106
- ]
1107
- else :
1108
- node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1117
+ elif parent_node_ids :
1118
+ for parent_node_id in parent_node_ids :
1119
+ if parent_node_id in ancestors_node_id_to_debug_handle :
1120
+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1121
+ parent_node_id
1122
+ ]
1123
+ break
1109
1124
1110
1125
bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
1111
1126
0 commit comments