@@ -1161,12 +1161,15 @@ def _consume_etrecord(self) -> None:
11611161 )
11621162
11631163 # TODO: Make it more extensible to further merge overlapping debug handles
1164- def _get_runtime_intermediate_outputs (self ) -> Dict [Tuple [int , ...], Any ]:
1164+ def _get_runtime_intermediate_outputs_and_op_names (
1165+ self ,
1166+ ) -> Tuple [Dict [Tuple [int , ...], Any ], Dict [Tuple [int , ...], str ]]:
11651167 """
1166- Retrieve the raw runtime intermediate outputs(debug handles and value mappings)
1167- from the event blocks. These outputs will be processed later to merge overlapping debug handles .
1168+ Retrieve the runtime intermediate outputs (a mapping from debug handles to intermediate values)
1169+ from the event blocks, along with a mapping from those debug handles to op names.
11681170 """
11691171 debug_handle_to_output = {}
1172+ debug_handle_to_op_name = {}
11701173 for event_block in self .event_blocks :
11711174 for event in event_block .events :
11721175 # Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
@@ -1175,20 +1178,23 @@ def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]:
11751178 or not event .op_types
11761179 ):
11771180 continue
1178- # Normalize debug_handles to a tuple
1179- debug_handles = event .debug_handles
1180- if isinstance (debug_handles , int ):
1181- debug_handles = (debug_handles ,)
1181+ # Normalize debug_handle to a tuple
1182+ debug_handle = event .debug_handles
1183+ if isinstance (debug_handle , int ):
1184+ debug_handle = (debug_handle ,)
11821185 else :
1183- debug_handles = tuple (debug_handles )
1184- current_entry = debug_handle_to_output .get (debug_handles , (- 1 , None ))
1185- # When event has same debug handles , only keep the one with the largest instruction id
1186+ debug_handle = tuple (debug_handle )
1187+ current_entry = debug_handle_to_output .get (debug_handle , (- 1 , None ))
1188+ # When several events share the same debug_handle, keep only the one with the largest instruction id
11861189 if event ._instruction_id > current_entry [0 ]:
1187- debug_handle_to_output [debug_handles ] = (
1190+ debug_handle_to_output [debug_handle ] = (
11881191 event ._instruction_id ,
11891192 event .debug_data ,
11901193 )
1191- return {k : v [1 ] for k , v in debug_handle_to_output .items ()}
1194+ debug_handle_to_op_name [debug_handle ] = event .name
1195+ return {
1196+ k : v [1 ] for k , v in debug_handle_to_output .items ()
1197+ }, debug_handle_to_op_name
11921198
11931199 def to_dataframe (
11941200 self ,
@@ -1364,8 +1370,12 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13641370 raise ValueError (
13651371 "The aot intermediate outputs is required but not populated."
13661372 )
1373+ # The runtime_op_names will be used later to map runtime debug_handle to op_name
1374+ runtime_intermediate_outputs , runtime_op_names = (
1375+ self ._get_runtime_intermediate_outputs_and_op_names ()
1376+ )
13671377 mapping = map_runtime_aot_intermediate_outputs (
1368- self ._aot_intermediate_outputs , self . _get_runtime_intermediate_outputs ()
1378+ self ._aot_intermediate_outputs , runtime_intermediate_outputs
13691379 )
13701380 metric = distance .strip ().upper ()
13711381 if metric == "MSE" :
0 commit comments