Skip to content

Commit 422d801

Browse files
Juntian Liu authored and facebook-github-bot committed
Mapping between runtime and aot intermediate outputs (#11624)
Summary: Pull Request resolved: #11624. This PR introduces a function, map_runtime_aot_intermediate_outputs, that maps runtime intermediate outputs to AOT intermediate outputs by identifying debug handles that overlap between AOT and runtime and combining them into a single key. It handles the following mapping scenarios: 1. No Overlaps: There are no overlapping debug handles between AOT and runtime outputs. 2. 1-to-1 Mapping: One AOT debug handle corresponds directly to one runtime debug handle. 3. 1-to-N Mapping: A single AOT debug handle maps to multiple runtime debug handles. 4. N-to-1 Mapping: Multiple AOT debug handles map to a single runtime debug handle. 5. N-to-N Mapping: Multiple AOT and runtime debug handles form a chain of overlaps. In all cases where multiple debug handles are involved (N-to-1, 1-to-N, or N-to-N), the function merges them into a single combined debug handle and retains only the last intermediate output for each mapping. To handle all cases, the code first pre-processes the inputs, then constructs a graph whose nodes represent debug handles and outputs, identifies connected components using DFS, and creates a mapping for each component that contains both AOT and runtime nodes, merging debug handles and retaining the last output. This function will be used later in the Inspector Numerical Comparator class to create the mapping. Differential Revision: D76442807
1 parent b59f5cc commit 422d801

File tree

2 files changed

+255
-1
lines changed

2 files changed

+255
-1
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import math
1010
import sys
11+
from dataclasses import dataclass
1112
from enum import Enum
1213
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
1314

@@ -72,6 +73,17 @@ class TimeScale(Enum):
7273
}
7374

7475

76+
# Tags recording which side of the comparison a graph node came from.
FROM_AOT = 1
FROM_RUNTIME = 2


@dataclass
class node_data:
    """One node of the AOT/runtime debug-handle overlap graph."""

    # Origin of the node: FROM_AOT or FROM_RUNTIME.
    source: int
    # Debug handle identifying the output; a tuple of ints of any length
    # (``tuple[int]`` would have declared a tuple of exactly one int).
    debug_handle: Tuple[int, ...]
    # Intermediate output recorded under ``debug_handle``.
    output: Any
85+
86+
7587
def calculate_time_scale_factor(
7688
source_time_scale: TimeScale, target_time_scale: TimeScale
7789
) -> float:
@@ -489,7 +501,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
489501
"""
490502
Merge overlapping debug handles int a single key
491503
"""
492-
if not intermediate_outputs:
504+
if len(intermediate_outputs) == 0:
493505
return
494506
# Extract and normalize into (start, end, val)
495507
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
@@ -512,3 +524,166 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
512524
intermediate_outputs.clear()
513525
for start, end, val in merged_intermediate_outputs:
514526
intermediate_outputs[tuple(range(start, end + 1))] = val
527+
528+
529+
def _debug_handles_have_overlap(
    aot_debug_hanlde: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...]
) -> bool:
    """
    Check if the AOT debug handle and the runtime debug handle have any overlap.

    Args:
        aot_debug_hanlde: Debug handle from the AOT side. The misspelled
            parameter name is kept so existing keyword callers keep working.
        runtime_debug_handle: Debug handle from the runtime side.

    Returns:
        True if at least one integer appears in both handles, False otherwise
        (including when either handle is empty).
    """
    # isdisjoint() stops at the first shared element and avoids building a
    # second set and a full intersection just to test for non-emptiness.
    return not set(aot_debug_hanlde).isdisjoint(runtime_debug_handle)
538+
539+
540+
def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Combine multiple debug handles into one debug handle.

    Args:
        debug_handles: The debug handles to merge.

    Returns:
        A tuple holding every distinct integer found in ``debug_handles``,
        in ascending order. An empty input yields an empty tuple.
    """
    # set().union(*...) also accepts zero arguments, so the empty case needs
    # no special handling.
    return tuple(sorted(set().union(*debug_handles)))
546+
547+
548+
def _combine_overlapped_intermediate_outputs(
    nodes: List[Tuple[Tuple[int, ...], Any]]
) -> Tuple[Tuple[int, ...], Any]:
    """Combine multiple overlapped intermediate outputs into one.

    Args:
        nodes: Non-empty list of ``(debug_handle, output)`` pairs.

    Returns:
        A ``(combined_debug_handle, output)`` pair, where the combined handle
        is produced by ``_combine_debug_hanldes`` from all handles in
        ``nodes`` and the output is the one from the last pair in ``nodes``.

    Raises:
        ValueError: If ``nodes`` is empty, since there is no output to pick.
    """
    if not nodes:
        # Fail with a clear message instead of an opaque IndexError below.
        raise ValueError("Cannot combine an empty list of intermediate outputs")
    combined_debug_handle = _combine_debug_hanldes(
        [debug_handle for debug_handle, _ in nodes]
    )
    _, last_output = nodes[-1]  # Pick the last one
    return combined_debug_handle, last_output
557+
558+
559+
def _create_debug_handle_overlap_graph(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Tuple[List[node_data], Dict[int, List[int]]]:
    """
    Create a graph representing overlapping debug handles between AOT and runtime outputs.

    Each node in the graph is an instance of node_data, which contains:
    - source: An int tag giving the origin of the node (FROM_AOT or FROM_RUNTIME).
    - debug_handle: The tuple of ints under which the output was recorded.
    - output: The actual output data associated with the debug handle.

    Nodes are listed in input order: every AOT entry first, followed by every
    runtime entry.

    Edges in the graph are represented as a dictionary where:
    - The key is the index of a node in the nodes list.
    - The value is a list (in ascending index order) of the nodes from the
      other source whose debug handles overlap with the key node. Nodes from
      the same source are never connected directly.

    Returns:
    - A tuple containing:
    - A list of node_data instances representing the nodes in the graph.
    - A dictionary representing the edges; every node index is present as a
      key, with an empty list when the node has no overlaps.
    """
    nodes = [
        node_data(FROM_AOT, debug_handle, output)
        for debug_handle, output in aot_intermediate_outputs.items()
    ]
    # AOT nodes occupy indices [0, num_aot_nodes); runtime nodes follow.
    num_aot_nodes = len(nodes)
    nodes.extend(
        node_data(FROM_RUNTIME, debug_handle, output)
        for debug_handle, output in runtime_intermediate_outputs.items()
    )

    edges: Dict[int, List[int]] = {i: [] for i in range(len(nodes))}
    # Only nodes from different sources can be connected, so pair each AOT
    # node with each runtime node instead of testing every pair of nodes.
    for aot_id in range(num_aot_nodes):
        for runtime_id in range(num_aot_nodes, len(nodes)):
            if _debug_handles_have_overlap(
                nodes[aot_id].debug_handle, nodes[runtime_id].debug_handle
            ):
                edges[aot_id].append(runtime_id)
                edges[runtime_id].append(aot_id)
    return (nodes, edges)
598+
599+
600+
def _find_connected_components(
    nodes: List["node_data"], edges: Dict[int, List[int]]
) -> List[List[int]]:
    """
    Find groups of connected nodes in a graph using DFS.

    The traversal is iterative, so the size of a component is not limited by
    the interpreter's recursion limit. Nodes are visited in the same order a
    recursive depth-first search would use.

    Parameters:
    - nodes: A list of nodes in the graph. Only its length is used here.
    - edges: A dictionary where each key is a node index, and the value is a list
      of indices of connected nodes. Every node index must be present as a key.
    Returns:
    - A list of connected components, each represented as a list of node
      indices in DFS visiting order. Components are ordered by their lowest
      node index.
    """
    visited = [False] * len(nodes)
    connected_components = []

    for start_id in range(len(nodes)):
        # Every unvisited node starts a new component.
        if visited[start_id]:
            continue
        visited[start_id] = True
        component = [start_id]
        # Each stack entry is the partially consumed neighbor iterator of a
        # node on the current DFS path; resuming it mirrors returning from a
        # recursive call.
        stack = [iter(edges[start_id])]
        while stack:
            for neighbor_node_id in stack[-1]:
                if not visited[neighbor_node_id]:
                    visited[neighbor_node_id] = True
                    component.append(neighbor_node_id)
                    stack.append(iter(edges[neighbor_node_id]))
                    break
            else:
                # All neighbors of the node on top were handled; backtrack.
                stack.pop()
        connected_components.append(component)
    return connected_components
633+
634+
635+
def map_runtime_aot_intermediate_outputs(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
    """
    Map the AOT intermediate outputs to the runtime intermediate outputs
    by finding overlapping debug handles and combining them into a single debug_handle

    Overlapping debug handles within each input are merged first. AOT and
    runtime entries whose debug handles overlap, directly or through a chain
    of overlaps, are then grouped together. Each group contributes one entry
    to the result; groups that contain only AOT or only runtime entries are
    dropped.

    The input dictionaries are not modified.

    Args:
        aot_intermediate_outputs: AOT outputs keyed by debug handle.
        runtime_intermediate_outputs: Runtime outputs keyed by debug handle.

    Returns:
        Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
        from ``(combined AOT debug handle, AOT output)`` to
        ``(combined runtime debug handle, runtime output)``. When a group
        holds several entries from one source, the output of the last of
        them (in the order of the merged input) is kept. AOT outputs are part
        of the dictionary keys and therefore have to be hashable.
    """
    # merge_overlapping_debug_handles rewrites the dict it is given, so work
    # on shallow copies to leave the caller's dictionaries untouched.
    aot_outputs = dict(aot_intermediate_outputs)
    runtime_outputs = dict(runtime_intermediate_outputs)
    merge_overlapping_debug_handles(aot_outputs)
    merge_overlapping_debug_handles(runtime_outputs)

    # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
    nodes, edges = _create_debug_handle_overlap_graph(aot_outputs, runtime_outputs)
    # Find connected(between aot and runtime) components
    connected_components = _find_connected_components(nodes, edges)

    aot_runtime_mapping = {}
    for comp in connected_components:
        aot_list = []
        runtime_list = []
        # Walk the component in node-index (i.e. input) order so that the
        # "last" output does not depend on the graph traversal order.
        for node_id in sorted(comp):
            node = nodes[node_id]
            entry = (node.debug_handle, node.output)
            if node.source == FROM_AOT:
                aot_list.append(entry)
            elif node.source == FROM_RUNTIME:
                runtime_list.append(entry)

        # Map only if both AOT and runtime data are present.
        if not aot_list or not runtime_list:
            continue

        # Combine each side into a single (debug_handle, output) element
        aot_combined = _combine_overlapped_intermediate_outputs(aot_list)
        runtime_combined = _combine_overlapped_intermediate_outputs(runtime_list)
        # Create a mapping between aot and runtime
        aot_runtime_mapping[aot_combined] = runtime_combined

    return aot_runtime_mapping

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
find_populated_event,
3535
gen_graphs_from_etrecord,
3636
is_inference_output_equal,
37+
map_runtime_aot_intermediate_outputs,
3738
merge_overlapping_debug_handles,
3839
TimeScale,
3940
)
@@ -238,6 +239,84 @@ def test_merge_overlapping_debug_handles(self):
238239
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
239240
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
240241

242+
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
    """Empty AOT and runtime inputs must yield an empty mapping."""
    aot, runtime = {}, {}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {})
251+
252+
def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
    """Single-element debug handles map one-to-one onto equal handles."""
    aot = {(0,): 100, (1,): 200, (2,): 300}
    runtime = {(0,): 150, (1,): 250, (2,): 350}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping,
        {
            ((0,), 100): ((0,), 150),
            ((1,), 200): ((1,), 250),
            ((2,), 300): ((2,), 350),
        },
    )
265+
266+
def test_map_runtime_aot_intermediate_outputs_exact_match(self):
    """Identical multi-element debug handles on both sides pair up directly."""
    aot = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
    runtime = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping,
        {
            ((0, 1), 100): ((0, 1), 150),
            ((2, 3), 200): ((2, 3), 200),
            ((4, 5), 300): ((4, 5), 300),
        },
    )
279+
280+
def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
    """Disjoint AOT and runtime debug handles produce no mapping entries."""
    aot = {(0, 1): 100, (4, 5): 300}
    runtime = {(2, 3): 200, (8, 9): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {})
289+
290+
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
    """Several AOT handles overlapping one runtime handle are merged (N-to-1)."""
    aot = {(0, 1, 2): 100, (3, 4): 300}
    runtime = {(1, 2, 3): 250, (8, 9): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    # The unmatched runtime handle (8, 9) must not appear in the result.
    self.assertEqual(mapping, {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)})
299+
300+
def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
    """One AOT handle overlapping several runtime handles is mapped once (1-to-N)."""
    aot = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
    runtime = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    # The unmatched AOT handle (8, 9) must not appear in the result.
    self.assertEqual(
        mapping, {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
    )
309+
310+
def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
    """A chain of alternating overlaps collapses into one entry (N-to-N)."""
    aot = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
    runtime = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping, {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
    )
319+
241320

242321
def gen_mock_operator_graph_with_expected_map() -> (
243322
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)