Skip to content

Commit 292c7b4

Browse files
author
Juntian Liu
authored
Correct debug handle merging logic
Differential Revision: D77461675 Pull Request resolved: #12073
1 parent 13e2668 commit 292c7b4

File tree

3 files changed

+77
-37
lines changed

3 files changed

+77
-37
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -530,47 +530,63 @@ def compare_results(
530530
return results
531531

532532

533-
def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]):
533+
def merge_overlapping_debug_handles(
534+
intermediate_outputs: Dict[DebugHandle, Any]
535+
) -> Dict[DebugHandle, Any]:
534536
"""
535-
Merge overlapping debug handles int a single key
537+
Merges overlapping debug handles into a single key in the dict.
538+
For each debug handle, this function checks for overlaps with existing keys in the merged dict.
539+
If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
540+
The value associated with the merged key is determined by the debug handle with the highest last element.
536541
"""
542+
537543
if len(intermediate_outputs) == 0:
538-
return
539-
# Extract and normalize into (start, end, val)
540-
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
541-
intervals.sort(key=lambda x: x[0])
542-
543-
# Merge overlapping debug_hanldes, picking the last value
544-
merged_intermediate_outputs = []
545-
cur_start, cur_end, cur_val = intervals[0]
546-
for start, end, val in intervals[1:]:
547-
if start <= cur_end: # Overlaps
548-
if end > cur_end: # Extend if this one goes further
549-
cur_end, cur_val = end, val
544+
return {}
550545

551-
else:
552-
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
553-
cur_start, cur_end, cur_val = start, end, val
554-
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
546+
merged: Dict[DebugHandle, Any] = {}
547+
548+
for debug_handle, value in intermediate_outputs.items():
549+
debug_handle_set = set(debug_handle)
550+
curr_debug_handle, last_value = debug_handle, value
551+
552+
# collect any existing keys that overlap with the current key
553+
to_remove = []
554+
for existing_debug_handle, existing_value in merged.items():
555+
if debug_handle_set.intersection(set(existing_debug_handle)):
556+
# abosrb their ints
557+
debug_handle_set |= set(existing_debug_handle)
558+
if existing_debug_handle[-1] > curr_debug_handle[-1]:
559+
curr_debug_handle, last_value = (
560+
existing_debug_handle,
561+
existing_value,
562+
)
563+
to_remove.append(existing_debug_handle)
555564

556-
# Clear original one and populate with merged keys (value will point to the same object)
557-
intermediate_outputs.clear()
558-
for start, end, val in merged_intermediate_outputs:
559-
intermediate_outputs[tuple(range(start, end + 1))] = val
565+
# remove all the keys that overlap with the current key
566+
for debug_handle in to_remove:
567+
merged.pop(debug_handle)
568+
569+
# add the current key to the merged one
570+
new_debug_handle = tuple(sorted(debug_handle_set))
571+
merged[new_debug_handle] = last_value
572+
573+
# Sort the merged debug handles in ascending order based on their last element
574+
# TODO: Consider adding more logic to align the order with the execution order
575+
return dict(sorted(merged.items(), key=lambda item: item[0][-1]))
560576

561577

562578
def _debug_handles_have_overlap(
563-
aot_debug_hanlde: DebugHandle, runtime_debug_handle: DebugHandle
579+
debug_handle: DebugHandle, target_debug_handle: DebugHandle
564580
) -> bool:
565581
"""
566-
Check if the AOT debug handle and the runtime debug handle have any overlap.
582+
Check if the debug handle and the target runtime debug handle have any overlap.
567583
"""
568-
aot_set = set(aot_debug_hanlde)
569-
runtime_set = set(runtime_debug_handle)
584+
aot_set = set(debug_handle)
585+
runtime_set = set(target_debug_handle)
570586
return len(aot_set.intersection(runtime_set)) > 0
571587

572588

573-
def _combine_debug_hanldes(debug_handles: List[DebugHandle]) -> DebugHandle:
589+
def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
574590
"""Combine multiple debug handles into one debug handle"""
575591
combined_debug_handles_set = set()
576592
for debug_handle in debug_handles:
@@ -584,7 +600,7 @@ def _combine_overlapped_intermediate_outputs(
584600
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
585601
debug_handles = [debug_handle for debug_handle, _ in nodes]
586602
outputs = [output for _, output in nodes]
587-
combined_debug_handle = _combine_debug_hanldes(debug_handles)
603+
combined_debug_handle = _combine_debug_handles(debug_handles)
588604
output = outputs[-1] # Pick the last one
589605
return combined_debug_handle, output
590606

@@ -673,8 +689,10 @@ def map_runtime_aot_intermediate_outputs(
673689
from runtime intermediate output to AOT intermediate output
674690
"""
675691
# Merge overlapping debug handles
676-
merge_overlapping_debug_handles(aot_intermediate_outputs)
677-
merge_overlapping_debug_handles(runtime_intermediate_outputs)
692+
aot_intermediate_outputs = merge_overlapping_debug_handles(aot_intermediate_outputs)
693+
runtime_intermediate_outputs = merge_overlapping_debug_handles(
694+
runtime_intermediate_outputs
695+
)
678696

679697
# Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
680698
nodes, edges = _create_debug_handle_overlap_graph(

devtools/inspector/tests/inspector_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def _gen_random_events(self) -> List[Event]:
665665
events = []
666666
for i in range(2):
667667
events.append(
668-
# OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2
668+
# OPERATOR_CALL with debug_handle/instruction_id 0 and 2
669669
Event(
670670
name="OPERATOR_CALL",
671671
op_types=[OP_TYPE],
@@ -676,7 +676,7 @@ def _gen_random_events(self) -> List[Event]:
676676
)
677677
)
678678
events.append(
679-
# op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3
679+
# op_0/op_1 wiht empty op_types and with debug_handle/instruction_id 1 and 3
680680
Event(
681681
name=f"op_{i}",
682682
op_types=[],
@@ -687,7 +687,7 @@ def _gen_random_events(self) -> List[Event]:
687687
)
688688
)
689689

690-
# op_2 with debug_hanldes/instruction_id 4
690+
# op_2 with debug_handle/instruction_id 4
691691
events.append(
692692
Event(
693693
name="op_2",
@@ -698,7 +698,7 @@ def _gen_random_events(self) -> List[Event]:
698698
_instruction_id=4,
699699
)
700700
)
701-
# op_3 also with debug_hanldes 4 but with instruction_id 5
701+
# op_3 also with debug_handle 4 but with instruction_id 5
702702
events.append(
703703
Event(
704704
name="op_3",
@@ -710,7 +710,7 @@ def _gen_random_events(self) -> List[Event]:
710710
)
711711
)
712712

713-
# op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9
713+
# op_4 to op_7 with debug_handle 5 to 8 and instruction_id 6 to 9
714714
for i in range(4, EVENTS_SIZE - 2):
715715
events.append(
716716
Event(

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_compare_results_uint8(self):
223223
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
224224
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
225225

226-
def test_merge_overlapping_debug_handles(self):
226+
def test_merge_overlapping_debug_handles_basic(self):
227227
big_tensor = torch.rand(100, 100)
228228
intermediate_outputs = {
229229
(1, 2, 3): "val1",
@@ -233,7 +233,7 @@ def test_merge_overlapping_debug_handles(self):
233233
(11, 12): big_tensor,
234234
}
235235
# basic merge behavior
236-
merge_overlapping_debug_handles(intermediate_outputs)
236+
intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs)
237237
expected_intermediate_outputs = {
238238
(1, 2, 3, 4, 5): "val2",
239239
(6, 7, 8): "val3",
@@ -243,6 +243,28 @@ def test_merge_overlapping_debug_handles(self):
243243
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
244244
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
245245

246+
def test_merge_overlapping_debug_handles_non_continuous(self):
247+
tensor1 = (torch.randn(3, 4),)
248+
tensor2 = (torch.randn(2, 3),)
249+
tensor3 = (torch.randn(4, 5),)
250+
tensor4 = (torch.randn(6, 7),)
251+
tensor5 = (torch.randn(8, 9),)
252+
intermediate_outputs = {
253+
(1, 10): tensor1,
254+
(2, 5): tensor2,
255+
(1, 7, 9): tensor3,
256+
(11, 13): tensor4,
257+
(11, 15): tensor5,
258+
}
259+
intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs)
260+
expected_intermediate_outputs = {
261+
(2, 5): tensor2,
262+
(1, 7, 9, 10): tensor1,
263+
(11, 13, 15): tensor5,
264+
}
265+
266+
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
267+
246268
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
247269
# When the inputs are empty, the output should also be empty
248270
aot_intermediate_outputs = {}

0 commit comments

Comments
 (0)