Skip to content

Commit e384b24

Browse files
Juntian Liufacebook-github-bot
authored andcommitted
Correct debug handle merging logic
Summary: This update refines the merge_overlapping_debug_handles function to handle non-continuous debug handle tuples effectively. Previously, the function assumed that debug handle tuples were continuous. The new implementation: 1.Merges overlapping debug handles by taking the union of their elements, allowing for non-continuous tuples. 2.Selects the value for the merged key based on the debug handle with the highest last element. Reviewed By: GregoryComer Differential Revision: D77461675
1 parent b74c68d commit e384b24

File tree

3 files changed

+75
-37
lines changed

3 files changed

+75
-37
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -530,47 +530,61 @@ def compare_results(
530530
return results
531531

532532

533-
def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]):
533+
def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]) -> Dict[DebugHandle, Any]:
534534
"""
535-
Merge overlapping debug handles int a single key
535+
Merges overlapping debug handles into a single key in the dict.
536+
For each debug handle, this function checks for overlaps with existing keys in the merged dict.
537+
If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
538+
The value associated with the merged key is determined by the debug handle with the highest last element.
536539
"""
540+
537541
if len(intermediate_outputs) == 0:
538-
return
539-
# Extract and normalize into (start, end, val)
540-
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
541-
intervals.sort(key=lambda x: x[0])
542-
543-
# Merge overlapping debug_hanldes, picking the last value
544-
merged_intermediate_outputs = []
545-
cur_start, cur_end, cur_val = intervals[0]
546-
for start, end, val in intervals[1:]:
547-
if start <= cur_end: # Overlaps
548-
if end > cur_end: # Extend if this one goes further
549-
cur_end, cur_val = end, val
542+
return {}
550543

551-
else:
552-
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
553-
cur_start, cur_end, cur_val = start, end, val
554-
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
544+
merged: Dict[DebugHandle, Any] = {}
545+
546+
for debug_handle, value in intermediate_outputs.items():
547+
debug_handle_set = set(debug_handle)
548+
curr_debug_handle, last_value = debug_handle, value
549+
550+
# collect any existing keys that overlap with the current key
551+
to_remove = []
552+
for existing_debug_handle, existing_value in merged.items():
553+
if debug_handle_set.intersection(set(existing_debug_handle)):
554+
# abosrb their ints
555+
debug_handle_set |= set(existing_debug_handle)
556+
if existing_debug_handle[-1] > curr_debug_handle[-1]:
557+
curr_debug_handle, last_value = (
558+
existing_debug_handle,
559+
existing_value,
560+
)
561+
to_remove.append(existing_debug_handle)
555562

556-
# Clear original one and populate with merged keys (value will point to the same object)
557-
intermediate_outputs.clear()
558-
for start, end, val in merged_intermediate_outputs:
559-
intermediate_outputs[tuple(range(start, end + 1))] = val
563+
# remove all the keys that overlap with the current key
564+
for debug_handle in to_remove:
565+
merged.pop(debug_handle)
566+
567+
# add the current key to the merged one
568+
new_debug_handle = tuple(sorted(debug_handle_set))
569+
merged[new_debug_handle] = last_value
570+
571+
# Sort the merged debug handles in ascending order based on their last element
572+
# TODO: Consider adding more logic to align the order with the execution order
573+
return dict(sorted(merged.items(), key=lambda item: item[0][-1]))
560574

561575

562576
def _debug_handles_have_overlap(
563-
aot_debug_hanlde: DebugHandle, runtime_debug_handle: DebugHandle
577+
debug_handle: DebugHandle, target_debug_handle: DebugHandle
564578
) -> bool:
565579
"""
566-
Check if the AOT debug handle and the runtime debug handle have any overlap.
580+
Check if the debug handle and the target runtime debug handle have any overlap.
567581
"""
568-
aot_set = set(aot_debug_hanlde)
569-
runtime_set = set(runtime_debug_handle)
582+
aot_set = set(debug_handle)
583+
runtime_set = set(target_debug_handle)
570584
return len(aot_set.intersection(runtime_set)) > 0
571585

572586

573-
def _combine_debug_hanldes(debug_handles: List[DebugHandle]) -> DebugHandle:
587+
def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
574588
"""Combine multiple debug handles into one debug handle"""
575589
combined_debug_handles_set = set()
576590
for debug_handle in debug_handles:
@@ -584,7 +598,7 @@ def _combine_overlapped_intermediate_outputs(
584598
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
585599
debug_handles = [debug_handle for debug_handle, _ in nodes]
586600
outputs = [output for _, output in nodes]
587-
combined_debug_handle = _combine_debug_hanldes(debug_handles)
601+
combined_debug_handle = _combine_debug_handles(debug_handles)
588602
output = outputs[-1] # Pick the last one
589603
return combined_debug_handle, output
590604

@@ -673,8 +687,10 @@ def map_runtime_aot_intermediate_outputs(
673687
from runtime intermediate output to AOT intermediate output
674688
"""
675689
# Merge overlapping debug handles
676-
merge_overlapping_debug_handles(aot_intermediate_outputs)
677-
merge_overlapping_debug_handles(runtime_intermediate_outputs)
690+
aot_intermediate_outputs = merge_overlapping_debug_handles(aot_intermediate_outputs)
691+
runtime_intermediate_outputs = merge_overlapping_debug_handles(
692+
runtime_intermediate_outputs
693+
)
678694

679695
# Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
680696
nodes, edges = _create_debug_handle_overlap_graph(

devtools/inspector/tests/inspector_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def _gen_random_events(self) -> List[Event]:
665665
events = []
666666
for i in range(2):
667667
events.append(
668-
# OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2
668+
# OPERATOR_CALL with debug_handle/instruction_id 0 and 2
669669
Event(
670670
name="OPERATOR_CALL",
671671
op_types=[OP_TYPE],
@@ -676,7 +676,7 @@ def _gen_random_events(self) -> List[Event]:
676676
)
677677
)
678678
events.append(
679-
# op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3
679+
# op_0/op_1 wiht empty op_types and with debug_handle/instruction_id 1 and 3
680680
Event(
681681
name=f"op_{i}",
682682
op_types=[],
@@ -687,7 +687,7 @@ def _gen_random_events(self) -> List[Event]:
687687
)
688688
)
689689

690-
# op_2 with debug_hanldes/instruction_id 4
690+
# op_2 with debug_handle/instruction_id 4
691691
events.append(
692692
Event(
693693
name="op_2",
@@ -698,7 +698,7 @@ def _gen_random_events(self) -> List[Event]:
698698
_instruction_id=4,
699699
)
700700
)
701-
# op_3 also with debug_hanldes 4 but with instruction_id 5
701+
# op_3 also with debug_handle 4 but with instruction_id 5
702702
events.append(
703703
Event(
704704
name="op_3",
@@ -710,7 +710,7 @@ def _gen_random_events(self) -> List[Event]:
710710
)
711711
)
712712

713-
# op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9
713+
# op_4 to op_7 with debug_handle 5 to 8 and instruction_id 6 to 9
714714
for i in range(4, EVENTS_SIZE - 2):
715715
events.append(
716716
Event(

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_compare_results_uint8(self):
223223
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
224224
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
225225

226-
def test_merge_overlapping_debug_handles(self):
226+
def test_merge_overlapping_debug_handles_basic(self):
227227
big_tensor = torch.rand(100, 100)
228228
intermediate_outputs = {
229229
(1, 2, 3): "val1",
@@ -233,7 +233,7 @@ def test_merge_overlapping_debug_handles(self):
233233
(11, 12): big_tensor,
234234
}
235235
# basic merge behavior
236-
merge_overlapping_debug_handles(intermediate_outputs)
236+
intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs)
237237
expected_intermediate_outputs = {
238238
(1, 2, 3, 4, 5): "val2",
239239
(6, 7, 8): "val3",
@@ -243,6 +243,28 @@ def test_merge_overlapping_debug_handles(self):
243243
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
244244
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
245245

246+
def test_merge_overlapping_debug_handles_non_continuous(self):
247+
tensor1 = (torch.randn(3, 4),)
248+
tensor2 = (torch.randn(2, 3),)
249+
tensor3 = (torch.randn(4, 5),)
250+
tensor4 = (torch.randn(6, 7),)
251+
tensor5 = (torch.randn(8, 9),)
252+
intermediate_outputs = {
253+
(1, 10): tensor1,
254+
(2, 5): tensor2,
255+
(1, 7, 9): tensor3,
256+
(11, 13): tensor4,
257+
(11, 15): tensor5,
258+
}
259+
intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs)
260+
expected_intermediate_outputs = {
261+
(2, 5): tensor2,
262+
(1, 7, 9, 10): tensor1,
263+
(11, 13, 15): tensor5,
264+
}
265+
266+
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
267+
246268
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
247269
# When the inputs are empty, the output should also be empty
248270
aot_intermediate_outputs = {}

0 commit comments

Comments
 (0)