diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 249a2203e4c..2b5d975abed 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -530,47 +530,63 @@ def compare_results( return results -def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]): +def merge_overlapping_debug_handles( + intermediate_outputs: Dict[DebugHandle, Any] +) -> Dict[DebugHandle, Any]: """ - Merge overlapping debug handles int a single key + Merges overlapping debug handles into a single key in the dict. + For each debug handle, this function checks for overlaps with existing keys in the merged dict. + If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements. + The value associated with the merged key is determined by the debug handle with the highest last element. """ + if len(intermediate_outputs) == 0: - return - # Extract and normalize into (start, end, val) - intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()] - intervals.sort(key=lambda x: x[0]) - - # Merge overlapping debug_hanldes, picking the last value - merged_intermediate_outputs = [] - cur_start, cur_end, cur_val = intervals[0] - for start, end, val in intervals[1:]: - if start <= cur_end: # Overlaps - if end > cur_end: # Extend if this one goes further - cur_end, cur_val = end, val + return {} - else: - merged_intermediate_outputs.append((cur_start, cur_end, cur_val)) - cur_start, cur_end, cur_val = start, end, val - merged_intermediate_outputs.append((cur_start, cur_end, cur_val)) + merged: Dict[DebugHandle, Any] = {} + + for debug_handle, value in intermediate_outputs.items(): + debug_handle_set = set(debug_handle) + curr_debug_handle, last_value = debug_handle, value + + # collect any existing keys that overlap with the current key + to_remove = [] + for existing_debug_handle, existing_value in merged.items(): + if debug_handle_set.intersection(set(existing_debug_handle)): + # abosrb their ints + debug_handle_set |= set(existing_debug_handle) + if existing_debug_handle[-1] > curr_debug_handle[-1]: + curr_debug_handle, last_value = ( + existing_debug_handle, + existing_value, + ) + to_remove.append(existing_debug_handle) - # Clear original one and populate with merged keys (value will point to the same object) - intermediate_outputs.clear() - for start, end, val in merged_intermediate_outputs: - intermediate_outputs[tuple(range(start, end + 1))] = val + # remove all the keys that overlap with the current key + for debug_handle in to_remove: + merged.pop(debug_handle) + + # add the current key to the merged one + new_debug_handle = tuple(sorted(debug_handle_set)) + merged[new_debug_handle] = last_value + + # Sort the merged debug handles in ascending order based on their last element + # TODO: Consider adding more logic to align the order with the execution order + return dict(sorted(merged.items(), key=lambda item: item[0][-1])) def _debug_handles_have_overlap( - aot_debug_hanlde: DebugHandle, runtime_debug_handle: DebugHandle + debug_handle: DebugHandle, target_debug_handle: DebugHandle ) -> bool: """ - Check if the AOT debug handle and the runtime debug handle have any overlap. + Check if the debug handle and the target runtime debug handle have any overlap. """ - aot_set = set(aot_debug_hanlde) - runtime_set = set(runtime_debug_handle) + aot_set = set(debug_handle) + runtime_set = set(target_debug_handle) return len(aot_set.intersection(runtime_set)) > 0 -def _combine_debug_hanldes(debug_handles: List[DebugHandle]) -> DebugHandle: +def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle: """Combine multiple debug handles into one debug handle""" combined_debug_handles_set = set() for debug_handle in debug_handles: @@ -584,7 +600,7 @@ def _combine_overlapped_intermediate_outputs( """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output""" debug_handles = [debug_handle for debug_handle, _ in nodes] outputs = [output for _, output in nodes] - combined_debug_handle = _combine_debug_hanldes(debug_handles) + combined_debug_handle = _combine_debug_handles(debug_handles) output = outputs[-1] # Pick the last one return combined_debug_handle, output @@ -673,8 +689,10 @@ def map_runtime_aot_intermediate_outputs( from runtime intermediate output to AOT intermediate output """ # Merge overlapping debug handles - merge_overlapping_debug_handles(aot_intermediate_outputs) - merge_overlapping_debug_handles(runtime_intermediate_outputs) + aot_intermediate_outputs = merge_overlapping_debug_handles(aot_intermediate_outputs) + runtime_intermediate_outputs = merge_overlapping_debug_handles( + runtime_intermediate_outputs + ) # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles nodes, edges = _create_debug_handle_overlap_graph( diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 28e33cca863..3ff30ad7269 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -665,7 +665,7 @@ def _gen_random_events(self) -> List[Event]: events = [] for i in range(2): events.append( - # OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2 + # OPERATOR_CALL with debug_handle/instruction_id 0 and 2 Event( name="OPERATOR_CALL", op_types=[OP_TYPE], @@ -676,7 +676,7 @@ def _gen_random_events(self) -> List[Event]: ) ) events.append( - # op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3 + # op_0/op_1 wiht empty op_types and with debug_handle/instruction_id 1 and 3 Event( name=f"op_{i}", op_types=[], @@ -687,7 +687,7 @@ def _gen_random_events(self) -> List[Event]: ) ) - # op_2 with debug_hanldes/instruction_id 4 + # op_2 with debug_handle/instruction_id 4 events.append( Event( name="op_2", @@ -698,7 +698,7 @@ def _gen_random_events(self) -> List[Event]: _instruction_id=4, ) ) - # op_3 also with debug_hanldes 4 but with instruction_id 5 + # op_3 also with debug_handle 4 but with instruction_id 5 events.append( Event( name="op_3", @@ -710,7 +710,7 @@ def _gen_random_events(self) -> List[Event]: ) ) - # op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9 + # op_4 to op_7 with debug_handle 5 to 8 and instruction_id 6 to 9 for i in range(4, EVENTS_SIZE - 2): events.append( Event( diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index b540f8dccd1..da46f080049 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -223,7 +223,7 @@ def test_compare_results_uint8(self): self.assertGreater(calculate_snr([a], [b])[0], 30.0) self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0) - def test_merge_overlapping_debug_handles(self): + def test_merge_overlapping_debug_handles_basic(self): big_tensor = torch.rand(100, 100) intermediate_outputs = { (1, 2, 3): "val1", @@ -233,7 +233,7 @@ def test_merge_overlapping_debug_handles(self): (11, 12): big_tensor, } # basic merge behavior - merge_overlapping_debug_handles(intermediate_outputs) + intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs) expected_intermediate_outputs = { (1, 2, 3, 4, 5): "val2", (6, 7, 8): "val3", @@ -243,6 +243,28 @@ def test_merge_overlapping_debug_handles(self): self.assertEqual(intermediate_outputs, expected_intermediate_outputs) self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor) + def test_merge_overlapping_debug_handles_non_continuous(self): + tensor1 = (torch.randn(3, 4),) + tensor2 = (torch.randn(2, 3),) + tensor3 = (torch.randn(4, 5),) + tensor4 = (torch.randn(6, 7),) + tensor5 = (torch.randn(8, 9),) + intermediate_outputs = { + (1, 10): tensor1, + (2, 5): tensor2, + (1, 7, 9): tensor3, + (11, 13): tensor4, + (11, 15): tensor5, + } + intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs) + expected_intermediate_outputs = { + (2, 5): tensor2, + (1, 7, 9, 10): tensor1, + (11, 13, 15): tensor5, + } + + self.assertEqual(intermediate_outputs, expected_intermediate_outputs) + def test_map_runtime_aot_intermediate_outputs_empty_inputs(self): # When the inputs are empty, the output should also be empty aot_intermediate_outputs = {}