@@ -530,47 +530,61 @@ def compare_results(
530530 return results
531531
532532
533- def merge_overlapping_debug_handles (intermediate_outputs : Dict [DebugHandle , Any ]):
533+ def merge_overlapping_debug_handles (intermediate_outputs : Dict [DebugHandle , Any ]) -> Dict [ DebugHandle , Any ] :
534534 """
535- Merge overlapping debug handles int a single key
535+ Merges overlapping debug handles into a single key in the dict.
536+ For each debug handle, this function checks for overlaps with existing keys in the merged dict.
537+ If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
538+ The value associated with the merged key is determined by the debug handle with the highest last element.
536539 """
540+
537541 if len (intermediate_outputs ) == 0 :
538- return
539- # Extract and normalize into (start, end, val)
540- intervals = [(min (key ), max (key ), val ) for key , val in intermediate_outputs .items ()]
541- intervals .sort (key = lambda x : x [0 ])
542-
543- # Merge overlapping debug_hanldes, picking the last value
544- merged_intermediate_outputs = []
545- cur_start , cur_end , cur_val = intervals [0 ]
546- for start , end , val in intervals [1 :]:
547- if start <= cur_end : # Overlaps
548- if end > cur_end : # Extend if this one goes further
549- cur_end , cur_val = end , val
542+ return {}
550543
551- else :
552- merged_intermediate_outputs .append ((cur_start , cur_end , cur_val ))
553- cur_start , cur_end , cur_val = start , end , val
554- merged_intermediate_outputs .append ((cur_start , cur_end , cur_val ))
544+ merged : Dict [DebugHandle , Any ] = {}
545+
546+ for debug_handle , value in intermediate_outputs .items ():
547+ debug_handle_set = set (debug_handle )
548+ curr_debug_handle , last_value = debug_handle , value
549+
550+ # collect any existing keys that overlap with the current key
551+ to_remove = []
552+ for existing_debug_handle , existing_value in merged .items ():
553+ if debug_handle_set .intersection (set (existing_debug_handle )):
554+ # abosrb their ints
555+ debug_handle_set |= set (existing_debug_handle )
556+ if existing_debug_handle [- 1 ] > curr_debug_handle [- 1 ]:
557+ curr_debug_handle , last_value = (
558+ existing_debug_handle ,
559+ existing_value ,
560+ )
561+ to_remove .append (existing_debug_handle )
555562
556- # Clear original one and populate with merged keys (value will point to the same object)
557- intermediate_outputs .clear ()
558- for start , end , val in merged_intermediate_outputs :
559- intermediate_outputs [tuple (range (start , end + 1 ))] = val
563+ # remove all the keys that overlap with the current key
564+ for debug_handle in to_remove :
565+ merged .pop (debug_handle )
566+
567+ # add the current key to the merged one
568+ new_debug_handle = tuple (sorted (debug_handle_set ))
569+ merged [new_debug_handle ] = last_value
570+
571+ # Sort the merged debug handles in ascending order based on their last element
572+ # TODO: Consider adding more logic to align the order with the execution order
573+ return dict (sorted (merged .items (), key = lambda item : item [0 ][- 1 ]))
560574
561575
562576def _debug_handles_have_overlap (
563- aot_debug_hanlde : DebugHandle , runtime_debug_handle : DebugHandle
577+ debug_handle : DebugHandle , target_debug_handle : DebugHandle
564578) -> bool :
565579 """
566- Check if the AOT debug handle and the runtime debug handle have any overlap.
580+ Check if the debug handle and the target runtime debug handle have any overlap.
567581 """
568- aot_set = set (aot_debug_hanlde )
569- runtime_set = set (runtime_debug_handle )
582+ aot_set = set (debug_handle )
583+ runtime_set = set (target_debug_handle )
570584 return len (aot_set .intersection (runtime_set )) > 0
571585
572586
573- def _combine_debug_hanldes (debug_handles : List [DebugHandle ]) -> DebugHandle :
587+ def _combine_debug_handles (debug_handles : List [DebugHandle ]) -> DebugHandle :
574588 """Combine multiple debug handles into one debug handle"""
575589 combined_debug_handles_set = set ()
576590 for debug_handle in debug_handles :
@@ -584,7 +598,7 @@ def _combine_overlapped_intermediate_outputs(
584598 """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
585599 debug_handles = [debug_handle for debug_handle , _ in nodes ]
586600 outputs = [output for _ , output in nodes ]
587- combined_debug_handle = _combine_debug_hanldes (debug_handles )
601+ combined_debug_handle = _combine_debug_handles (debug_handles )
588602 output = outputs [- 1 ] # Pick the last one
589603 return combined_debug_handle , output
590604
@@ -673,8 +687,10 @@ def map_runtime_aot_intermediate_outputs(
673687 from runtime intermediate output to AOT intermediate output
674688 """
675689 # Merge overlapping debug handles
676- merge_overlapping_debug_handles (aot_intermediate_outputs )
677- merge_overlapping_debug_handles (runtime_intermediate_outputs )
690+ aot_intermediate_outputs = merge_overlapping_debug_handles (aot_intermediate_outputs )
691+ runtime_intermediate_outputs = merge_overlapping_debug_handles (
692+ runtime_intermediate_outputs
693+ )
678694
679695 # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
680696 nodes , edges = _create_debug_handle_overlap_graph (
0 commit comments