Skip to content

Commit f63e298

Browse files
author
Juntian Liu
authored
Updated pre-processing runtime merging logic
Differential Revision: D77905958 Pull Request resolved: #12302
1 parent 6ac5df2 commit f63e298

File tree

3 files changed

+90
-65
lines changed

3 files changed

+90
-65
lines changed

devtools/inspector/_inspector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
is_debug_output,
6161
is_inference_output_equal,
6262
map_runtime_aot_intermediate_outputs,
63+
merge_runtime_overlapping_debug_handles,
6364
ProgramOutput,
6465
RESERVED_FRAMEWORK_EVENT_NAMES,
6566
TimeScale,
@@ -1208,6 +1209,8 @@ def _get_runtime_intermediate_outputs_and_op_names(
12081209
event.debug_data,
12091210
)
12101211
debug_handle_to_op_name[debug_handle] = event.name
1212+
1213+
merge_runtime_overlapping_debug_handles(debug_handle_to_output)
12111214
return {
12121215
k: v[1] for k, v in debug_handle_to_output.items()
12131216
}, debug_handle_to_op_name
@@ -1387,7 +1390,7 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13871390
)
13881391
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_name) == 0:
13891392
raise ValueError(
1390-
"calculate_numerical_gap error: The aot debug information is required but not populated"
1393+
"Missing etrecord or missing representative inputs within etrecord, both of which are required for calculating numerical gap"
13911394
)
13921395
# The runtime_op_names will be used later to map runtime debug_handle to op_name
13931396
runtime_intermediate_outputs, runtime_debug_handle_to_op_name = (

devtools/inspector/_inspector_utils.py

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -538,49 +538,71 @@ def compare_results(
538538
return results
539539

540540

541-
def merge_overlapping_debug_handles(
542-
intermediate_outputs: Dict[DebugHandle, Any]
543-
) -> Dict[DebugHandle, Any]:
541+
def _merge_runtime_debug_handles(
542+
debug_handle1: DebugHandle, debug_handle2: DebugHandle
543+
) -> DebugHandle:
544544
"""
545-
Merges overlapping debug handles into a single key in the dict.
546-
For each debug handle, this function checks for overlaps with existing keys in the merged dict.
547-
If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
548-
The value associated with the merged key is determined by the debug handle with the highest last element.
545+
Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
546+
while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
547+
All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
549548
"""
550549

551-
if len(intermediate_outputs) == 0:
552-
return {}
550+
# Initialize a list to store unique elements in order
551+
unique_ordered_list = []
552+
553+
# Initialize a set to track elements that have already been seen
554+
seen = set(debug_handle2)
555+
556+
for item in debug_handle1:
557+
# If the element has not been seen before, add it to the list and mark it as seen
558+
if item not in seen:
559+
unique_ordered_list.append(item)
553560

554-
merged: Dict[DebugHandle, Any] = {}
561+
for item in debug_handle2:
562+
unique_ordered_list.append(item)
563+
return tuple(unique_ordered_list)
555564

556-
for debug_handle, value in intermediate_outputs.items():
557-
debug_handle_set = set(debug_handle)
558-
curr_debug_handle, last_value = debug_handle, value
559565

560-
# collect any existing keys that overlap with the current key
566+
def merge_runtime_overlapping_debug_handles(
567+
intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]]
568+
) -> Dict[DebugHandle, Tuple[int, Any]]:
569+
"""
570+
Merges runtimes with overlapping debug handles into a single key in the dict.
571+
572+
For each debug handle, this function checks for overlaps with existing keys.
573+
If overlaps are found, it combines the overlapping keys into a single key by taking
574+
the union of their elements while maintaining the order. The order is preserved such that
575+
higher instruction_id appears after the debug_handle with lower instruction_id.
576+
577+
The value associated with the merged key is determined by the debug handle with the highest instruction id.
578+
"""
579+
if len(intermediate_outputs) == 0:
580+
return {}
581+
merged: Dict[DebugHandle, Tuple[int, Any]] = {}
582+
for debug_handle, (instruction_id, debug_data) in intermediate_outputs.items():
583+
curr_debug_handle, last_value = debug_handle, (instruction_id, debug_data)
584+
# Collect any existing keys that overlap with the current key
561585
to_remove = []
562586
for existing_debug_handle, existing_value in merged.items():
563-
if debug_handle_set.intersection(set(existing_debug_handle)):
564-
# abosrb their ints
565-
debug_handle_set |= set(existing_debug_handle)
566-
if existing_debug_handle[-1] > curr_debug_handle[-1]:
567-
curr_debug_handle, last_value = (
568-
existing_debug_handle,
569-
existing_value,
587+
if any(item in existing_debug_handle for item in debug_handle):
588+
# Keep the value with the highest instruction_id
589+
# Also merge the debug handles higher instruction_id
590+
if existing_value[0] < instruction_id:
591+
curr_debug_handle = _merge_runtime_debug_handles(
592+
existing_debug_handle, curr_debug_handle
570593
)
594+
else:
595+
curr_debug_handle = _merge_runtime_debug_handles(
596+
curr_debug_handle, existing_debug_handle
597+
)
598+
last_value = existing_value
571599
to_remove.append(existing_debug_handle)
572-
573-
# remove all the keys that overlap with the current key
600+
# Remove all the keys that overlap with the current key
574601
for debug_handle in to_remove:
575602
merged.pop(debug_handle)
576-
577-
# add the current key to the merged one
578-
new_debug_handle = tuple(sorted(debug_handle_set))
579-
merged[new_debug_handle] = last_value
580-
581-
# Sort the merged debug handles in ascending order based on their last element
582-
# TODO: Consider adding more logic to align the order with the execution order
583-
return dict(sorted(merged.items(), key=lambda item: item[0][-1]))
603+
# Add the current key to the merged one
604+
merged[curr_debug_handle] = last_value
605+
return merged
584606

585607

586608
def _debug_handles_have_overlap(
@@ -696,12 +718,6 @@ def map_runtime_aot_intermediate_outputs(
696718
Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] - Mapping
697719
from runtime intermediate output to AOT intermediate output
698720
"""
699-
# Merge overlapping debug handles
700-
aot_intermediate_outputs = merge_overlapping_debug_handles(aot_intermediate_outputs)
701-
runtime_intermediate_outputs = merge_overlapping_debug_handles(
702-
runtime_intermediate_outputs
703-
)
704-
705721
# Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
706722
nodes, edges = _create_debug_handle_overlap_graph(
707723
aot_intermediate_outputs, runtime_intermediate_outputs

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
get_aot_debug_handle_to_op_name_mapping,
4040
is_inference_output_equal,
4141
map_runtime_aot_intermediate_outputs,
42-
merge_overlapping_debug_handles,
42+
merge_runtime_overlapping_debug_handles,
4343
NodeFilter,
4444
TimeScale,
4545
)
@@ -228,44 +228,50 @@ def test_compare_results_uint8(self):
228228
def test_merge_overlapping_debug_handles_basic(self):
    """Overlapping handles collapse into one key that keeps the latest value."""
    big_tensor = torch.rand(100, 100)
    intermediate_outputs = {
        (1, 2, 3): (1, "val1"),
        (2, 3, 4, 5): (2, "val2"),
        (6, 7, 8): (3, "val3"),
        (10, 11): (4, "val4"),
        (11, 12): (5, big_tensor),
    }
    # basic merge behavior
    merged_outputs = merge_runtime_overlapping_debug_handles(intermediate_outputs)
    expected_intermediate_outputs = {
        (1, 2, 3, 4, 5): (2, "val2"),
        (6, 7, 8): (3, "val3"),
        (10, 11, 12): (5, big_tensor),
    }
    self.assertEqual(merged_outputs, expected_intermediate_outputs)
    # Check the actual output (not the expected dict) so this assertion can
    # fail if the merge copies or replaces the tensor.
    self.assertIs(merged_outputs[(10, 11, 12)][1], big_tensor)
247248

248249
def test_merge_overlapping_debug_handles_non_continuous(self):
    """Handles with gaps merge in instruction order and keep the latest value."""
    tensor1 = torch.randn(3, 4)
    tensor2 = torch.randn(2, 3)
    tensor3 = torch.randn(4, 5)
    tensor4 = torch.randn(6, 7)
    tensor5 = torch.randn(8, 9)
    intermediate_outputs = {
        (1, 10): (1, tensor1),
        (2, 5): (2, tensor2),
        (1, 7, 9): (3, tensor3),
        (11, 13): (4, tensor4),
        (11, 15): (5, tensor5),
    }
    merged_outputs = merge_runtime_overlapping_debug_handles(intermediate_outputs)
    expected_intermediate_outputs = {
        (2, 5): (2, tensor2),
        (10, 1, 7, 9): (3, tensor3),
        (13, 11, 15): (5, tensor5),
    }

    # Compare key sets first so unexpected extra or missing handles are caught.
    self.assertEqual(set(merged_outputs), set(expected_intermediate_outputs))
    for key, (expected_id, expected_value) in expected_intermediate_outputs.items():
        actual_id, actual_value = merged_outputs[key]
        self.assertEqual(actual_id, expected_id)
        self.assertTrue(torch.allclose(expected_value, actual_value))
269275

270276
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
271277
# When the inputs are empty, the output should also be empty

0 commit comments

Comments
 (0)