@@ -554,6 +554,7 @@ def _merge_runtime_debug_handles(
554554 Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
555555 while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
556556 All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
557+ Also removes duplicates within debug_handle2.
557558 """
558559
559560 # Initialize a list to store unique elements in order
@@ -566,14 +567,16 @@ def _merge_runtime_debug_handles(
566567 # If the element has not been seen before, add it to the list and mark it as seen
567568 if item not in seen :
568569 unique_ordered_list .append (item )
569-
570+ seen = set ( unique_ordered_list )
570571 for item in debug_handle2 :
571- unique_ordered_list .append (item )
572+ if item not in seen :
573+ unique_ordered_list .append (item )
574+ seen .add (item )
572575 return tuple (unique_ordered_list )
573576
574577
575578def merge_runtime_overlapping_debug_handles (
576- intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
579+ runtime_intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
577580) -> Dict [DebugHandle , Tuple [int , Any ]]:
578581 """
579582 Merges runtimes with overlapping debug handles into a single key in the dict.
@@ -585,15 +588,15 @@ def merge_runtime_overlapping_debug_handles(
585588
586589 The value associated with the merged key is determined by the debug handle with the highest instruction id.
587590 """
588- if len (intermediate_outputs ) == 0 :
591+ if len (runtime_intermediate_outputs ) == 0 :
589592 return {}
590593 merged : Dict [DebugHandle , Tuple [int , Any ]] = {}
591- for debug_handle , (instruction_id , debug_data ) in intermediate_outputs .items ():
594+ for debug_handle , (instruction_id , debug_data ) in runtime_intermediate_outputs .items ():
592595 curr_debug_handle , last_value = debug_handle , (instruction_id , debug_data )
593596 # Collect any existing keys that overlap with the current key
594597 to_remove = []
595598 for existing_debug_handle , existing_value in merged .items ():
596- if any ( item in existing_debug_handle for item in debug_handle ):
599+ if set ( debug_handle ) & set ( existing_debug_handle ):
597600 # Keep the value with the highest instruction_id
598601 # Also merge the debug handles higher instruction_id
599602 if existing_value [0 ] < instruction_id :
@@ -759,7 +762,11 @@ def map_runtime_aot_intermediate_outputs(
759762 # The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760763 # Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
761764 # As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762- assert len (runtime_list ) == 1
765+ if len (runtime_list ) != 1 :
766+ raise ValueError (
767+ f"Expected only one runtime debug handle, but found { len (runtime_list )} : { runtime_list } "
768+ )
769+
763770 runtime_debug_handle , runtime_intermediate_output = runtime_list [0 ]
764771
765772 # Combine aot debug handles into a single key
0 commit comments