@@ -1161,12 +1161,16 @@ def _consume_etrecord(self) -> None:
11611161 )
11621162
11631163 # TODO: Make it more extensible to further merge overlapping debug handles
1164- def _get_runtime_intermediate_outputs (self ) -> Dict [Tuple [int , ...], Any ]:
1164+ def _get_runtime_intermediate_outputs_and_op_names (
1165+ self ,
1166+ ) -> Tuple [Dict [Tuple [int , ...], Any ], Dict [Tuple [int , ...], str ]]:
11651167 """
1166- Retrieve the raw runtime intermediate outputs(debug handles and value mappings)
1167- from the event blocks. These outputs will be processed later to merge overlapping debug handles.
1168+ Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
1169+ from the event blocks, along with the corresponding debug handles and op names mapping
1170+ These outputs will be processed later to merge overlapping debug handles.
11681171 """
11691172 debug_handle_to_output = {}
1173+ debug_handle_to_op_name = {}
11701174 for event_block in self .event_blocks :
11711175 for event in event_block .events :
11721176 # Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
@@ -1175,20 +1179,23 @@ def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
11751179 or not event .op_types
11761180 ):
11771181 continue
1178- # Normalize debug_handles to a tuple
1179- debug_handles = event .debug_handles
1180- if isinstance (debug_handles , int ):
1181- debug_handles = (debug_handles ,)
1182+ # Normalize debug_handle to a tuple
1183+ debug_handle = event .debug_handles
1184+ if isinstance (debug_handle , int ):
1185+ debug_handle = (debug_handle ,)
11821186 else :
1183- debug_handles = tuple (debug_handles )
1184- current_entry = debug_handle_to_output .get (debug_handles , (- 1 , None ))
1185- # When event has same debug handles , only keep the one with the largest instruction id
1187+ debug_handle = tuple (debug_handle )
1188+ current_entry = debug_handle_to_output .get (debug_handle , (- 1 , None ))
1189+ # When event has same debug_handle , only keep the one with the largest instruction id
11861190 if event ._instruction_id > current_entry [0 ]:
1187- debug_handle_to_output [debug_handles ] = (
1191+ debug_handle_to_output [debug_handle ] = (
11881192 event ._instruction_id ,
11891193 event .debug_data ,
11901194 )
1191- return {k : v [1 ] for k , v in debug_handle_to_output .items ()}
1195+ debug_handle_to_op_name [debug_handle ] = event .name
1196+ return {
1197+ k : v [1 ] for k , v in debug_handle_to_output .items ()
1198+ }, debug_handle_to_op_name
11921199
11931200 def to_dataframe (
11941201 self ,
@@ -1364,8 +1371,12 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13641371 raise ValueError (
13651372 "The aot intermediate outputs is required but not populated."
13661373 )
1374+ # The runtime_op_names will be used later to map runtime debug_handle to op_name
1375+ runtime_intermediate_outputs , runtime_op_names = (
1376+ self ._get_runtime_intermediate_outputs_and_op_names ()
1377+ )
13671378 mapping = map_runtime_aot_intermediate_outputs (
1368- self ._aot_intermediate_outputs , self . _get_runtime_intermediate_outputs ()
1379+ self ._aot_intermediate_outputs , runtime_intermediate_outputs
13691380 )
13701381 metric = distance .strip ().upper ()
13711382 if metric == "MSE" :
0 commit comments