@@ -538,49 +538,71 @@ def compare_results(
538538 return results
539539
540540
541- def merge_overlapping_debug_handles (
542- intermediate_outputs : Dict [ DebugHandle , Any ]
543- ) -> Dict [ DebugHandle , Any ] :
541+ def _merge_runtime_debug_handles (
542+ debug_handle1 : DebugHandle , debug_handle2 : DebugHandle
543+ ) -> DebugHandle :
544544 """
545- Merges overlapping debug handles into a single key in the dict.
546- For each debug handle, this function checks for overlaps with existing keys in the merged dict.
547- If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
548- The value associated with the merged key is determined by the debug handle with the highest last element.
545+ Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
546+ while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
547+ All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
549548 """
550549
551- if len (intermediate_outputs ) == 0 :
552- return {}
550+ # Initialize a list to store unique elements in order
551+ unique_ordered_list = []
552+
553+ # Initialize a set to track elements that have already been seen
554+ seen = set (debug_handle2 )
555+
556+ for item in debug_handle1 :
557+ # If the element has not been seen before, add it to the list and mark it as seen
558+ if item not in seen :
559+ unique_ordered_list .append (item )
553560
554- merged : Dict [DebugHandle , Any ] = {}
561+ for item in debug_handle2 :
562+ unique_ordered_list .append (item )
563+ return tuple (unique_ordered_list )
555564
556- for debug_handle , value in intermediate_outputs .items ():
557- debug_handle_set = set (debug_handle )
558- curr_debug_handle , last_value = debug_handle , value
559565
560- # collect any existing keys that overlap with the current key
566+ def merge_runtime_overlapping_debug_handles (
567+ intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
568+ ) -> Dict [DebugHandle , Tuple [int , Any ]]:
569+ """
570+ Merges runtimes with overlapping debug handles into a single key in the dict.
571+
572+ For each debug handle, this function checks for overlaps with existing keys.
573+ If overlaps are found, it combines the overlapping keys into a single key by taking
574+ the union of their elements while maintaining the order. The order is preserved such that
575+ higher instruction_id appears after the debug_handle with lower instruction_id.
576+
577+ The value associated with the merged key is determined by the debug handle with the highest instruction id.
578+ """
579+ if len (intermediate_outputs ) == 0 :
580+ return {}
581+ merged : Dict [DebugHandle , Tuple [int , Any ]] = {}
582+ for debug_handle , (instruction_id , debug_data ) in intermediate_outputs .items ():
583+ curr_debug_handle , last_value = debug_handle , (instruction_id , debug_data )
584+ # Collect any existing keys that overlap with the current key
561585 to_remove = []
562586 for existing_debug_handle , existing_value in merged .items ():
563- if debug_handle_set .intersection (set (existing_debug_handle )):
564- # abosrb their ints
565- debug_handle_set |= set (existing_debug_handle )
566- if existing_debug_handle [- 1 ] > curr_debug_handle [- 1 ]:
567- curr_debug_handle , last_value = (
568- existing_debug_handle ,
569- existing_value ,
587+ if any (item in existing_debug_handle for item in debug_handle ):
588+ # Keep the value with the highest instruction_id
589+ # Also merge the debug handles higher instruction_id
590+ if existing_value [0 ] < instruction_id :
591+ curr_debug_handle = _merge_runtime_debug_handles (
592+ existing_debug_handle , curr_debug_handle
570593 )
594+ else :
595+ curr_debug_handle = _merge_runtime_debug_handles (
596+ curr_debug_handle , existing_debug_handle
597+ )
598+ last_value = existing_value
571599 to_remove .append (existing_debug_handle )
572-
573- # remove all the keys that overlap with the current key
600+ # Remove all the keys that overlap with the current key
574601 for debug_handle in to_remove :
575602 merged .pop (debug_handle )
576-
577- # add the current key to the merged one
578- new_debug_handle = tuple (sorted (debug_handle_set ))
579- merged [new_debug_handle ] = last_value
580-
581- # Sort the merged debug handles in ascending order based on their last element
582- # TODO: Consider adding more logic to align the order with the execution order
583- return dict (sorted (merged .items (), key = lambda item : item [0 ][- 1 ]))
603+ # Add the current key to the merged one
604+ merged [curr_debug_handle ] = last_value
605+ return merged
584606
585607
586608def _debug_handles_have_overlap (
@@ -696,12 +718,6 @@ def map_runtime_aot_intermediate_outputs(
696718 Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] - Mapping
697719 from runtime intermediate output to AOT intermediate output
698720 """
699- # Merge overlapping debug handles
700- aot_intermediate_outputs = merge_overlapping_debug_handles (aot_intermediate_outputs )
701- runtime_intermediate_outputs = merge_overlapping_debug_handles (
702- runtime_intermediate_outputs
703- )
704-
705721 # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
706722 nodes , edges = _create_debug_handle_overlap_graph (
707723 aot_intermediate_outputs , runtime_intermediate_outputs
0 commit comments