@@ -1161,12 +1161,15 @@ def _consume_etrecord(self) -> None:
1161
1161
)
1162
1162
1163
1163
# TODO: Make it more extensible to further merge overlapping debug handles
1164
- def _get_runtime_intermediate_outputs (self ) -> Dict [Tuple [int , ...], Any ]:
1164
+ def _get_runtime_intermediate_outputs_and_op_names (
1165
+ self ,
1166
+ ) -> Tuple [Dict [Tuple [int , ...], Any ], Dict [Tuple [int , ...], str ]]:
1165
1167
"""
1166
- Retrieve the raw runtime intermediate outputs(debug handles and value mappings)
1167
- from the event blocks. These outputs will be processed later to merge overlapping debug handles .
1168
+ Retrieve the runtime intermediate outputs (a mapping from debug handles to intermediate values)
1169
+ from the event blocks, along with the corresponding mapping from debug handles to op names.
1168
1170
"""
1169
1171
debug_handle_to_output = {}
1172
+ debug_handle_to_op_name = {}
1170
1173
for event_block in self .event_blocks :
1171
1174
for event in event_block .events :
1172
1175
# Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
@@ -1175,20 +1178,23 @@ def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
1175
1178
or not event .op_types
1176
1179
):
1177
1180
continue
1178
- # Normalize debug_handles to a tuple
1179
- debug_handles = event .debug_handles
1180
- if isinstance (debug_handles , int ):
1181
- debug_handles = (debug_handles ,)
1181
+ # Normalize debug_handle to a tuple
1182
+ debug_handle = event .debug_handles
1183
+ if isinstance (debug_handle , int ):
1184
+ debug_handle = (debug_handle ,)
1182
1185
else :
1183
- debug_handles = tuple (debug_handles )
1184
- current_entry = debug_handle_to_output .get (debug_handles , (- 1 , None ))
1185
- # When event has same debug handles , only keep the one with the largest instruction id
1186
+ debug_handle = tuple (debug_handle )
1187
+ current_entry = debug_handle_to_output .get (debug_handle , (- 1 , None ))
1188
+ # When multiple events share the same debug_handle, only keep the one with the largest instruction id
1186
1189
if event ._instruction_id > current_entry [0 ]:
1187
- debug_handle_to_output [debug_handles ] = (
1190
+ debug_handle_to_output [debug_handle ] = (
1188
1191
event ._instruction_id ,
1189
1192
event .debug_data ,
1190
1193
)
1191
- return {k : v [1 ] for k , v in debug_handle_to_output .items ()}
1194
+ debug_handle_to_op_name [debug_handle ] = event .name
1195
+ return {
1196
+ k : v [1 ] for k , v in debug_handle_to_output .items ()
1197
+ }, debug_handle_to_op_name
1192
1198
1193
1199
def to_dataframe (
1194
1200
self ,
@@ -1364,8 +1370,12 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
1364
1370
raise ValueError (
1365
1371
"The aot intermediate outputs is required but not populated."
1366
1372
)
1373
+ # The runtime_op_names will be used later to map runtime debug_handle to op_name
1374
+ runtime_intermediate_outputs , runtime_op_names = (
1375
+ self ._get_runtime_intermediate_outputs_and_op_names ()
1376
+ )
1367
1377
mapping = map_runtime_aot_intermediate_outputs (
1368
- self ._aot_intermediate_outputs , self . _get_runtime_intermediate_outputs ()
1378
+ self ._aot_intermediate_outputs , runtime_intermediate_outputs
1369
1379
)
1370
1380
metric = distance .strip ().upper ()
1371
1381
if metric == "MSE" :
0 commit comments