@@ -554,6 +554,7 @@ def _merge_runtime_debug_handles(
554554 Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
555555 while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
556556 All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
557+ Also removes duplicates within debug_handle2.
557558 """
558559
559560 # Initialize a list to store unique elements in order
@@ -566,14 +567,16 @@ def _merge_runtime_debug_handles(
566567 # If the element has not been seen before, add it to the list and mark it as seen
567568 if item not in seen :
568569 unique_ordered_list .append (item )
569-
570+ seen = set ( unique_ordered_list )
570571 for item in debug_handle2 :
571- unique_ordered_list .append (item )
572+ if item not in seen :
573+ unique_ordered_list .append (item )
574+ seen .add (item )
572575 return tuple (unique_ordered_list )
573576
574577
575578def merge_runtime_overlapping_debug_handles (
576- intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
579+ runtime_intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
577580) -> Dict [DebugHandle , Tuple [int , Any ]]:
578581 """
579582 Merges runtimes with overlapping debug handles into a single key in the dict.
@@ -585,15 +588,18 @@ def merge_runtime_overlapping_debug_handles(
585588
586589 The value associated with the merged key is determined by the debug handle with the highest instruction id.
587590 """
588- if len (intermediate_outputs ) == 0 :
591+ if len (runtime_intermediate_outputs ) == 0 :
589592 return {}
590593 merged : Dict [DebugHandle , Tuple [int , Any ]] = {}
591- for debug_handle , (instruction_id , debug_data ) in intermediate_outputs .items ():
594+ for debug_handle , (
595+ instruction_id ,
596+ debug_data ,
597+ ) in runtime_intermediate_outputs .items ():
592598 curr_debug_handle , last_value = debug_handle , (instruction_id , debug_data )
593599 # Collect any existing keys that overlap with the current key
594600 to_remove = []
595601 for existing_debug_handle , existing_value in merged .items ():
596- if any ( item in existing_debug_handle for item in debug_handle ):
602+ if set ( debug_handle ) & set ( existing_debug_handle ):
597603 # Keep the value with the highest instruction_id
598604 # Also merge the debug handles higher instruction_id
599605 if existing_value [0 ] < instruction_id :
@@ -759,7 +765,11 @@ def map_runtime_aot_intermediate_outputs(
759765 # The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760766 # Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
761767 # As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762- assert len (runtime_list ) == 1
768+ if len (runtime_list ) != 1 :
769+ raise ValueError (
770+ f"Expected only one runtime debug handle, but found { len (runtime_list )} : { runtime_list } "
771+ )
772+
763773 runtime_debug_handle , runtime_intermediate_output = runtime_list [0 ]
764774
765775 # Combine aot debug handles into a single key
0 commit comments