@@ -530,47 +530,63 @@ def compare_results(
530
530
return results
531
531
532
532
533
- def merge_overlapping_debug_handles (intermediate_outputs : Dict [DebugHandle , Any ]):
533
+ def merge_overlapping_debug_handles (
534
+ intermediate_outputs : Dict [DebugHandle , Any ]
535
+ ) -> Dict [DebugHandle , Any ]:
534
536
"""
535
- Merge overlapping debug handles int a single key
537
+ Merges overlapping debug handles into a single key in the dict.
538
+ For each debug handle, this function checks for overlaps with existing keys in the merged dict.
539
+ If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
540
+ The value associated with the merged key is determined by the debug handle with the highest last element.
536
541
"""
542
+
537
543
if len (intermediate_outputs ) == 0 :
538
- return
539
- # Extract and normalize into (start, end, val)
540
- intervals = [(min (key ), max (key ), val ) for key , val in intermediate_outputs .items ()]
541
- intervals .sort (key = lambda x : x [0 ])
542
-
543
- # Merge overlapping debug_hanldes, picking the last value
544
- merged_intermediate_outputs = []
545
- cur_start , cur_end , cur_val = intervals [0 ]
546
- for start , end , val in intervals [1 :]:
547
- if start <= cur_end : # Overlaps
548
- if end > cur_end : # Extend if this one goes further
549
- cur_end , cur_val = end , val
544
+ return {}
550
545
551
- else :
552
- merged_intermediate_outputs .append ((cur_start , cur_end , cur_val ))
553
- cur_start , cur_end , cur_val = start , end , val
554
- merged_intermediate_outputs .append ((cur_start , cur_end , cur_val ))
546
+ merged : Dict [DebugHandle , Any ] = {}
547
+
548
+ for debug_handle , value in intermediate_outputs .items ():
549
+ debug_handle_set = set (debug_handle )
550
+ curr_debug_handle , last_value = debug_handle , value
551
+
552
+ # collect any existing keys that overlap with the current key
553
+ to_remove = []
554
+ for existing_debug_handle , existing_value in merged .items ():
555
+ if debug_handle_set .intersection (set (existing_debug_handle )):
556
+ # abosrb their ints
557
+ debug_handle_set |= set (existing_debug_handle )
558
+ if existing_debug_handle [- 1 ] > curr_debug_handle [- 1 ]:
559
+ curr_debug_handle , last_value = (
560
+ existing_debug_handle ,
561
+ existing_value ,
562
+ )
563
+ to_remove .append (existing_debug_handle )
555
564
556
- # Clear original one and populate with merged keys (value will point to the same object)
557
- intermediate_outputs .clear ()
558
- for start , end , val in merged_intermediate_outputs :
559
- intermediate_outputs [tuple (range (start , end + 1 ))] = val
565
+ # remove all the keys that overlap with the current key
566
+ for debug_handle in to_remove :
567
+ merged .pop (debug_handle )
568
+
569
+ # add the current key to the merged one
570
+ new_debug_handle = tuple (sorted (debug_handle_set ))
571
+ merged [new_debug_handle ] = last_value
572
+
573
+ # Sort the merged debug handles in ascending order based on their last element
574
+ # TODO: Consider adding more logic to align the order with the execution order
575
+ return dict (sorted (merged .items (), key = lambda item : item [0 ][- 1 ]))
560
576
561
577
562
578
def _debug_handles_have_overlap (
563
- aot_debug_hanlde : DebugHandle , runtime_debug_handle : DebugHandle
579
+ debug_handle : DebugHandle , target_debug_handle : DebugHandle
564
580
) -> bool :
565
581
"""
566
- Check if the AOT debug handle and the runtime debug handle have any overlap.
582
+ Check if the debug handle and the target runtime debug handle have any overlap.
567
583
"""
568
- aot_set = set (aot_debug_hanlde )
569
- runtime_set = set (runtime_debug_handle )
584
+ aot_set = set (debug_handle )
585
+ runtime_set = set (target_debug_handle )
570
586
return len (aot_set .intersection (runtime_set )) > 0
571
587
572
588
573
- def _combine_debug_hanldes (debug_handles : List [DebugHandle ]) -> DebugHandle :
589
+ def _combine_debug_handles (debug_handles : List [DebugHandle ]) -> DebugHandle :
574
590
"""Combine multiple debug handles into one debug handle"""
575
591
combined_debug_handles_set = set ()
576
592
for debug_handle in debug_handles :
@@ -584,7 +600,7 @@ def _combine_overlapped_intermediate_outputs(
584
600
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
585
601
debug_handles = [debug_handle for debug_handle , _ in nodes ]
586
602
outputs = [output for _ , output in nodes ]
587
- combined_debug_handle = _combine_debug_hanldes (debug_handles )
603
+ combined_debug_handle = _combine_debug_handles (debug_handles )
588
604
output = outputs [- 1 ] # Pick the last one
589
605
return combined_debug_handle , output
590
606
@@ -673,8 +689,10 @@ def map_runtime_aot_intermediate_outputs(
673
689
from runtime intermediate output to AOT intermediate output
674
690
"""
675
691
# Merge overlapping debug handles
676
- merge_overlapping_debug_handles (aot_intermediate_outputs )
677
- merge_overlapping_debug_handles (runtime_intermediate_outputs )
692
+ aot_intermediate_outputs = merge_overlapping_debug_handles (aot_intermediate_outputs )
693
+ runtime_intermediate_outputs = merge_overlapping_debug_handles (
694
+ runtime_intermediate_outputs
695
+ )
678
696
679
697
# Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
680
698
nodes , edges = _create_debug_handle_overlap_graph (
0 commit comments