diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py
index 20f407d5a11..67fcc807752 100644
--- a/devtools/inspector/_inspector_utils.py
+++ b/devtools/inspector/_inspector_utils.py
@@ -8,6 +8,7 @@
 import math
 import sys
 
+from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
 
@@ -72,6 +73,25 @@ class TimeScale(Enum):
 }
 
 
+class NodeSource(Enum):
+    AOT = 1
+    RUNTIME = 2
+
+
+@dataclass
+class NodeData:
+    """
+    Each node in the graph is an instance of NodeData, which contains:
+    - source: A NodeSource member indicating the origin of the node (AOT or RUNTIME).
+    - debug_handle: A tuple of ints representing the unique identifier for the output.
+    - output: The actual output data associated with the debug handle.
+    """
+
+    source: NodeSource
+    debug_handle: Tuple[int, ...]
+    output: Any
+
+
 def calculate_time_scale_factor(
     source_time_scale: TimeScale, target_time_scale: TimeScale
 ) -> float:
@@ -489,7 +509,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
     """
     Merge overlapping debug handles int a single key
     """
-    if not intermediate_outputs:
+    if len(intermediate_outputs) == 0:
         return
     # Extract and normalize into (start, end, val)
     intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
@@ -512,3 +532,178 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
     intermediate_outputs.clear()
     for start, end, val in merged_intermediate_outputs:
         intermediate_outputs[tuple(range(start, end + 1))] = val
+
+
+def _debug_handles_have_overlap(
+    aot_debug_handle: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...]
+) -> bool:
+    """
+    Check if the AOT debug handle and the runtime debug handle have any overlap.
+    """
+    # isdisjoint() stops at the first shared handle instead of building the
+    # whole intersection just to test it for emptiness.
+    return not set(aot_debug_handle).isdisjoint(runtime_debug_handle)
+
+
+def _combine_debug_handles(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]:
+    """Combine multiple debug handles into one sorted, de-duplicated debug handle"""
+    combined_debug_handles_set = set()
+    for debug_handle in debug_handles:
+        combined_debug_handles_set.update(debug_handle)
+    return tuple(sorted(combined_debug_handles_set))
+
+
+def _combine_overlapped_intermediate_outputs(
+    nodes: List[Tuple[Tuple[int, ...], Any]]
+) -> Tuple[Tuple[int, ...], Any]:
+    """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
+    debug_handles = [debug_handle for debug_handle, _ in nodes]
+    combined_debug_handle = _combine_debug_handles(debug_handles)
+    _, output = nodes[-1]  # Pick the last one
+    return combined_debug_handle, output
+
+
+def _create_debug_handle_overlap_graph(
+    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
+    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
+) -> Tuple[List[NodeData], Dict[int, List[int]]]:
+    """
+    Create a graph representing overlapping debug handles between AOT and runtime outputs.
+
+    Edges in the graph are represented as a dictionary where:
+    - The key is the index of a node in the nodes list.
+    - The value is a list of indices of nodes that have overlapping debug handles with the key node.
+
+    Returns:
+    - A tuple containing:
+        - A list of NodeData instances representing the nodes in the graph
+          (all AOT nodes first, followed by all runtime nodes).
+        - A dictionary representing the edges, where each key-value pair indicates connected nodes due to overlapping debug handles.
+    """
+    nodes = [
+        NodeData(NodeSource.AOT, debug_handle, output)
+        for debug_handle, output in aot_intermediate_outputs.items()
+    ]
+    num_aot_nodes = len(nodes)
+    nodes.extend(
+        NodeData(NodeSource.RUNTIME, debug_handle, output)
+        for debug_handle, output in runtime_intermediate_outputs.items()
+    )
+
+    edges: Dict[int, List[int]] = {i: [] for i in range(len(nodes))}
+    # AOT nodes occupy indices [0, num_aot_nodes) and runtime nodes the rest, so
+    # pairing the two index ranges visits exactly the cross-source pairs; nodes
+    # from the same source are never connected to each other.
+    for i in range(num_aot_nodes):
+        for j in range(num_aot_nodes, len(nodes)):
+            if _debug_handles_have_overlap(
+                nodes[i].debug_handle, nodes[j].debug_handle
+            ):
+                edges[i].append(j)
+                edges[j].append(i)
+    return (nodes, edges)
+
+
+def _find_connected_components(
+    nodes: List[NodeData], edges: Dict[int, List[int]]
+) -> List[List[int]]:
+    """
+    Find groups of connected nodes in a graph using an iterative DFS.
+
+    The traversal keeps an explicit stack instead of recursing, so long chains of
+    overlapping debug handles cannot exceed Python's recursion limit. Nodes are
+    still visited in the same pre-order a recursive DFS would produce.
+
+    Parameters:
+    - nodes: A list of nodes in the graph.
+    - edges: A dictionary where each key is a node index, and the value is a list
+             of indices of connected nodes.
+    Returns:
+    - A list of connected components, each represented as a list of node indices.
+    """
+    visited = [False] * len(nodes)
+    connected_components = []
+
+    # Start a new DFS from every node that no earlier traversal has reached
+    for start_node_id in range(len(nodes)):
+        if visited[start_node_id]:
+            continue
+        visited[start_node_id] = True
+        component = [start_node_id]
+        # Each stack entry is the iterator over the not-yet-explored neighbors of
+        # a node on the current DFS path
+        stack = [iter(edges[start_node_id])]
+        while stack:
+            for neighbor_node_id in stack[-1]:
+                if not visited[neighbor_node_id]:
+                    visited[neighbor_node_id] = True
+                    component.append(neighbor_node_id)
+                    # Descend into the neighbor before finishing the current node
+                    stack.append(iter(edges[neighbor_node_id]))
+                    break
+            else:
+                # All neighbors of the node on top of the stack are explored
+                stack.pop()
+        # After visiting all reachable nodes, add the current component to the list
+        connected_components.append(component)
+    return connected_components
+
+
+def map_runtime_aot_intermediate_outputs(
+    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
+    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
+) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
+    """
+    Map the AOT intermediate outputs to the runtime intermediate outputs
+    by finding overlapping debug handles and combining them into a single debug_handle
+
+    Note: both input dictionaries are modified in place, because overlapping debug
+    handles within each of them are merged before the mapping is built.
+
+    Returns:
+        Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
+        from (AOT debug_handle, AOT output) to (runtime debug_handle, runtime output)
+    """
+    # Merge overlapping debug handles
+    merge_overlapping_debug_handles(aot_intermediate_outputs)
+    merge_overlapping_debug_handles(runtime_intermediate_outputs)
+
+    # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
+    nodes, edges = _create_debug_handle_overlap_graph(
+        aot_intermediate_outputs, runtime_intermediate_outputs
+    )
+    # Find connected(between aot and runtime) components
+    connected_components = _find_connected_components(nodes, edges)
+
+    aot_runtime_mapping = {}
+    for comp in connected_components:
+        # Separate nodes into AOT and runtime lists based on their source,
+        # each list is combined into a single element and mapped to each other.
+        aot_list = [
+            (nodes[node_id].debug_handle, nodes[node_id].output)
+            for node_id in comp
+            if nodes[node_id].source == NodeSource.AOT
+        ]
+        runtime_list = [
+            (nodes[node_id].debug_handle, nodes[node_id].output)
+            for node_id in comp
+            if nodes[node_id].source == NodeSource.RUNTIME
+        ]
+
+        # Map only if both AOT and runtime data are present.
+        if len(aot_list) != 0 and len(runtime_list) != 0:
+            # Combine aot debug handles into a single key
+            aot_combined_debug_handle, aot_output = (
+                _combine_overlapped_intermediate_outputs(aot_list)
+            )
+            # Combine runtime debug handles into a single key
+            runtime_combined_debug_handle, runtime_output = (
+                _combine_overlapped_intermediate_outputs(runtime_list)
+            )
+            # Create a mapping between runtime and aot
+            aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = (
+                runtime_combined_debug_handle,
+                runtime_output,
+            )
+
+    return aot_runtime_mapping
diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py
index e6ccb1deda4..38ed2c29ea2 100644
--- a/devtools/inspector/tests/inspector_utils_test.py
+++ b/devtools/inspector/tests/inspector_utils_test.py
@@ -34,6 +34,7 @@
     find_populated_event,
     gen_graphs_from_etrecord,
     is_inference_output_equal,
+    map_runtime_aot_intermediate_outputs,
     merge_overlapping_debug_handles,
     TimeScale,
 )
@@ -238,6 +239,104 @@ def test_merge_overlapping_debug_handles(self):
         self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
         self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
 
+    def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
+        # When the inputs are empty, the output should also be empty
+        aot_intermediate_outputs = {}
+        runtime_intermediate_outputs = {}
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {}
+        self.assertEqual(actual, expected)
+
+    def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
+        # Single element tuple
+        aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300}
+        runtime_intermediate_outputs = {(0,): 150, (1,): 250, (2,): 350}
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {
+            ((0,), 100): ((0,), 150),
+            ((1,), 200): ((1,), 250),
+            ((2,), 300): ((2,), 350),
+        }
+        self.assertEqual(actual, expected)
+
+    def test_map_runtime_aot_intermediate_outputs_exact_match(self):
+        # Exact match between aot and runtime debug_handles
+        aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
+        runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {
+            ((0, 1), 100): ((0, 1), 150),
+            ((2, 3), 200): ((2, 3), 200),
+            ((4, 5), 300): ((4, 5), 300),
+        }
+        self.assertEqual(actual, expected)
+
+    def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
+        # No overlaps between aot and runtime debug_handles
+        aot_intermediate_outputs = {(0, 1): 100, (4, 5): 300}
+        runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {}
+        self.assertEqual(actual, expected)
+
+    def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
+        # Multiple aot debug_handles map to one runtime debug_handle
+        aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
+        runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
+        self.assertEqual(actual, expected)
+
+    def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
+        # One aot debug_handle map to multiple runtime debug_handles
+        aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
+        runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
+        self.assertEqual(actual, expected)
+
+    def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
+        # Complex chain (N-to-N mapping)
+        aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
+        runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
+        self.assertEqual(actual, expected)
+
+    def test_map_runtime_aot_intermediate_outputs_long_chain(self):
+        # A single chain with more nodes than the default recursion limit
+        num_handles = 600
+        aot_intermediate_outputs = {
+            (2 * i, 2 * i + 1): i for i in range(num_handles)
+        }
+        runtime_intermediate_outputs = {
+            (2 * i + 1, 2 * i + 2): -i for i in range(num_handles)
+        }
+        actual = map_runtime_aot_intermediate_outputs(
+            aot_intermediate_outputs, runtime_intermediate_outputs
+        )
+        expected = {
+            (tuple(range(2 * num_handles)), num_handles - 1): (
+                tuple(range(1, 2 * num_handles + 1)),
+                1 - num_handles,
+            )
+        }
+        self.assertEqual(actual, expected)
+
 
 def gen_mock_operator_graph_with_expected_map() -> (
     Tuple[OperatorGraph, Dict[int, OperatorNode]]