Skip to content

Commit da6beae

Browse files
Juntian Liufacebook-github-bot
authored andcommitted
Mapping between runtime and aot intermediate outputs
Summary: This PR introduces a function, map_runtime_aot_intermediate_outputs, that maps runtime intermediate outputs to AOT intermediate outputs by identifying debug handles that overlap between AOT and runtime and combining them into a single key. It handles the following mapping scenarios: 1. No Overlaps: There are no overlapping debug handles between AOT and runtime outputs. 2. 1-to-1 Mapping: A straightforward mapping where one AOT debug handle corresponds directly to one runtime debug handle. 3. 1-to-N Mapping: A single AOT debug handle maps to multiple runtime debug handles. 4. N-to-1 Mapping: Multiple AOT debug handles map to a single runtime debug handle. 5. N-to-N Mapping: More intricate scenarios where multiple AOT and runtime debug handles form a chain of overlaps. In all cases where multiple debug handles are involved (N-to-1, 1-to-N, or N-to-N), the function merges these into a single combined debug handle and retains only the last intermediate output for each mapping. To handle all cases, the code first pre-processes the inputs, then constructs a graph of nodes representing debug handles and outputs, identifies connected components using DFS, and creates mappings for overlapping components, merging debug handles and retaining the last output. This function will be used later in the Inspector Numerical Comparator class to create the mapping. Differential Revision: D76442807
1 parent 8895573 commit da6beae

File tree

2 files changed

+237
-1
lines changed

2 files changed

+237
-1
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import math
1010
import sys
11+
from dataclasses import dataclass
1112
from enum import Enum
1213
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
1314

@@ -72,6 +73,17 @@ class TimeScale(Enum):
7273
}
7374

7475

76+
# Tags recording which side produced a graph node's intermediate output.
FROM_AOT = 1
FROM_RUNTIME = 2


@dataclass
class node_data:
    """
    One node of the AOT/runtime overlap graph.

    The lowercase class name is kept as-is so existing references keep working.
    """

    # Origin of the node: FROM_AOT or FROM_RUNTIME.
    source: int
    # Debug handle of the node; a variable-length tuple of ints
    # (previously annotated as `tuple[int]`, which denotes a 1-tuple only).
    debug_handle: Tuple[int, ...]
    # Intermediate output associated with the debug handle.
    output: Any
85+
86+
7587
def calculate_time_scale_factor(
7688
source_time_scale: TimeScale, target_time_scale: TimeScale
7789
) -> float:
@@ -489,7 +501,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
489501
"""
490502
Merge overlapping debug handles int a single key
491503
"""
492-
if not intermediate_outputs:
504+
if len(intermediate_outputs) == 0:
493505
return
494506
# Extract and normalize into (start, end, val)
495507
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
@@ -512,3 +524,148 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
512524
intermediate_outputs.clear()
513525
for start, end, val in merged_intermediate_outputs:
514526
intermediate_outputs[tuple(range(start, end + 1))] = val
527+
528+
529+
def _has_overlaps(
    aot_debug_hanlde: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...]
) -> bool:
    """
    Tell whether an AOT debug handle and a runtime debug handle overlap.

    Args:
        aot_debug_hanlde: Debug handle from the AOT side.
        runtime_debug_handle: Debug handle from the runtime side.

    Returns:
        True when the two handles have at least one integer in common,
        False otherwise (including when either handle is empty).
    """
    shares_nothing = set(aot_debug_hanlde).isdisjoint(runtime_debug_handle)
    return not shares_nothing
538+
539+
540+
def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]:
    """
    Merge several debug handles into one.

    Returns:
        A tuple holding every distinct integer found in any of the given
        handles, in ascending order. An empty input list yields ().
    """
    unique_ids = {
        handle_id for debug_handle in debug_handles for handle_id in debug_handle
    }
    return tuple(sorted(unique_ids))
546+
547+
548+
def _combine_overlapped_intermediate_outputs(
    nodes: List[Tuple[Tuple[int, ...], Any]]
) -> Tuple[Tuple[int, ...], Any]:
    """
    Collapse several overlapping intermediate outputs into a single entry.

    Args:
        nodes: (debug_handle, output) pairs; must be non-empty, otherwise an
            IndexError is raised when the last entry is read.

    Returns:
        A pair of (combined debug handle, output), where the combined handle
        covers every given debug handle and the output is the one belonging
        to the last pair in `nodes`.
    """
    merged_handle = _combine_debug_hanldes([handle for handle, _ in nodes])
    # Only the output of the last pair is retained.
    _, last_output = nodes[-1]
    return merged_handle, last_output
557+
558+
559+
def _create_graph(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Tuple[List[node_data], Dict[int, List[int]]]:
    """
    Create a graph of debug handles that overlap between AOT and runtime.

    Args:
        aot_intermediate_outputs: AOT debug handle -> intermediate output.
        runtime_intermediate_outputs: Runtime debug handle -> intermediate output.

    Returns:
        A (nodes, edges) pair. `nodes` lists every AOT entry first, followed by
        every runtime entry, each as node_data(source, debug_handle, output)
        with source FROM_AOT or FROM_RUNTIME. `edges` is an adjacency list:
        it maps each node id (its index in `nodes`) to the ids of the nodes it
        is connected to. Nodes without any connection map to an empty list.
    """
    nodes = [
        node_data(FROM_AOT, debug_handle, output)
        for debug_handle, output in aot_intermediate_outputs.items()
    ]
    # AOT nodes occupy ids [0, num_aot); runtime nodes occupy the rest.
    num_aot = len(nodes)
    nodes.extend(
        node_data(FROM_RUNTIME, debug_handle, output)
        for debug_handle, output in runtime_intermediate_outputs.items()
    )

    edges: Dict[int, List[int]] = {node_id: [] for node_id in range(len(nodes))}
    # Only nodes from different sources can be connected, so it is enough to
    # compare each AOT node with each runtime node; same-source pairs are
    # never inspected.
    for aot_id in range(num_aot):
        aot_debug_handle = nodes[aot_id].debug_handle
        for runtime_id in range(num_aot, len(nodes)):
            if _has_overlaps(aot_debug_handle, nodes[runtime_id].debug_handle):
                edges[aot_id].append(runtime_id)
                edges[runtime_id].append(aot_id)
    return (nodes, edges)
586+
587+
588+
def _find_connected_components(
    nodes: "List[node_data]", edges: Dict[int, List[int]]
) -> List[List[int]]:
    """
    Find the connected components of the AOT/runtime overlap graph using DFS.

    The traversal uses an explicit stack instead of recursion, so a very long
    chain of overlapping nodes cannot exceed the interpreter recursion limit.
    Nodes are still reported in depth-first pre-order.

    Args:
        nodes: Graph nodes; only the number of nodes is used here.
        edges: Adjacency list mapping every node id in range(len(nodes)) to
            the ids of its neighbors.

    Returns:
        One list of node ids per connected component. Components are ordered
        by their smallest node id, and each component starts with that id.
    """
    visited = [False] * len(nodes)
    connected_components = []

    for start_id in range(len(nodes)):
        if visited[start_id]:
            continue

        visited[start_id] = True
        component = [start_id]
        # Each stack entry is the partially consumed neighbor iterator of a
        # node on the current DFS path.
        stack = [iter(edges[start_id])]
        while stack:
            for neighbor_node_id in stack[-1]:
                if not visited[neighbor_node_id]:
                    visited[neighbor_node_id] = True
                    component.append(neighbor_node_id)
                    # Descend; the parent iterator resumes after this
                    # neighbor's subtree has been fully explored.
                    stack.append(iter(edges[neighbor_node_id]))
                    break
            else:
                # Every neighbor of the top node has been handled.
                stack.pop()
        connected_components.append(component)
    return connected_components
616+
617+
618+
def map_runtime_aot_intermediate_outputs(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
    """
    Pair AOT intermediate outputs with runtime intermediate outputs by finding
    debug handles that overlap between the two sides and combining them into a
    single debug handle per side.

    Note: both input dictionaries are modified in place, because
    merge_overlapping_debug_handles is applied to each of them first.

    Args:
        aot_intermediate_outputs: AOT debug handle -> intermediate output.
        runtime_intermediate_outputs: Runtime debug handle -> intermediate output.

    Returns:
        A dictionary whose keys are (combined AOT debug handle, AOT output)
        and whose values are (combined runtime debug handle, runtime output).
        Entries without an overlapping counterpart on the other side are
        omitted. Since the AOT output is part of the dictionary key, it must
        be hashable.
    """
    # Merge overlapping debug handles within each side
    merge_overlapping_debug_handles(aot_intermediate_outputs)
    merge_overlapping_debug_handles(runtime_intermediate_outputs)

    # Create a graph (nodes and edges) of debug handles overlapping between
    # aot and runtime, then group it into connected components
    nodes, edges = _create_graph(aot_intermediate_outputs, runtime_intermediate_outputs)
    connected_components = _find_connected_components(nodes, edges)

    aot_runtime_mapping = {}
    for comp in connected_components:
        # Separate the component's nodes into AOT and runtime lists
        aot_list = []
        runtime_list = []
        for node_id in comp:
            node = nodes[node_id]
            side = aot_list if node.source == FROM_AOT else runtime_list
            side.append((node.debug_handle, node.output))

        # Map only if both AOT and runtime data are present.
        if not aot_list or not runtime_list:
            continue

        # Each side collapses to one (combined debug handle, output) pair
        aot_entry = _combine_overlapped_intermediate_outputs(aot_list)
        runtime_entry = _combine_overlapped_intermediate_outputs(runtime_list)
        aot_runtime_mapping[aot_entry] = runtime_entry

    return aot_runtime_mapping

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
find_populated_event,
3535
gen_graphs_from_etrecord,
3636
is_inference_output_equal,
37+
map_runtime_aot_intermediate_outputs,
3738
merge_overlapping_debug_handles,
3839
TimeScale,
3940
)
@@ -238,6 +239,84 @@ def test_merge_overlapping_debug_handles(self):
238239
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
239240
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
240241

242+
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
    """Empty AOT and runtime inputs must produce an empty mapping."""
    aot, runtime = {}, {}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {})
251+
252+
def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
    """Single-element debug handles are mapped one-to-one."""
    aot = {(0,): 100, (1,): 200, (2,): 300}
    runtime = {(0,): 150, (1,): 250, (2,): 350}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping,
        {
            ((0,), 100): ((0,), 150),
            ((1,), 200): ((1,), 250),
            ((2,), 300): ((2,), 350),
        },
    )
265+
266+
def test_map_runtime_aot_intermediate_outputs_exact_match(self):
    """Identical AOT and runtime debug handles are mapped pairwise."""
    aot = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
    runtime = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping,
        {
            ((0, 1), 100): ((0, 1), 150),
            ((2, 3), 200): ((2, 3), 200),
            ((4, 5), 300): ((4, 5), 300),
        },
    )
279+
280+
def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
    """Disjoint AOT and runtime debug handles yield no mapping at all."""
    aot = {(0, 1): 100, (4, 5): 300}
    runtime = {(2, 3): 200, (8, 9): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {})
289+
290+
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
    """Several AOT debug handles overlapping one runtime handle are combined."""
    aot = {(0, 1, 2): 100, (3, 4): 300}
    runtime = {(1, 2, 3): 250, (8, 9): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(mapping, {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)})
299+
300+
def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
    """One AOT debug handle overlapping several runtime handles combines them."""
    aot = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
    runtime = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping, {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
    )
309+
310+
def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
    """A chain of overlaps (N-to-N mapping) collapses into a single entry."""
    aot = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
    runtime = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
    mapping = map_runtime_aot_intermediate_outputs(aot, runtime)
    self.assertEqual(
        mapping, {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
    )
319+
241320

242321
def gen_mock_operator_graph_with_expected_map() -> (
243322
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)