Skip to content

Commit 7bcb6d3

Browse files
authored
Allow for matching debug handles with partial overlap between aten graph and runtime
Differential Revision: D82229367 Pull Request resolved: #14306
1 parent 03f436a commit 7bcb6d3

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs(
657657
# Combine all AOT debug_handles into a list
658658
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
659659

660-
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
661-
# AOT combined debug_handle and runtime debug_handle do not match.
660+
# Reason we don't check for an exact match:
661+
# in some experiments where we want to rewrite the aten graph that was
662+
# lowered, so as to use custom ops like int4_matmul, we lose some nodes
663+
# on the graph and thus lose some debug handles, so we don't find an
664+
# exact match within connected components.
665+
if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)):
666+
# AOT combined debug_handle is not a subset of runtime debug_handle.
662667
return (-1,), None
663668

664669
# Pick the last intermediate output
665670
last_int = runtime_debug_handle[negative_index]
666671
key = (last_int,)
672+
if key not in aot_map:
673+
# If the last intermediate output is not in the AOT map, report no match
674+
return (-1,), None
667675
return runtime_debug_handle, aot_map[key]
668676

669677

@@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None:
10591067
if node.op in ("output", "placeholder"):
10601068
return
10611069
node_id = f"{node.name}.{exported_program_graph_id}"
1062-
parent_node_id = get_parent_node_identifier(node)
1070+
parent_node_ids = get_ancestor_node_identifiers(node)
10631071
if node_id in ancestors_node_id_to_debug_handle:
10641072
matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id])
1065-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1066-
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])
1073+
elif parent_node_ids:
1074+
for parent_node_id in parent_node_ids:
1075+
if parent_node_id in ancestors_node_id_to_debug_handle:
1076+
matched_debug_handles.add(
1077+
ancestors_node_id_to_debug_handle[parent_node_id]
1078+
)
1079+
break
10671080

10681081
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
10691082
return matched_debug_handles
@@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None:
10971110
if node.op in ("output", "placeholder"):
10981111
return
10991112
node_id = f"{node.name}.{exported_program_graph_id}"
1100-
parent_node_id = get_parent_node_identifier(node)
1113+
parent_node_ids = get_ancestor_node_identifiers(node)
1114+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
11011115
if node_id in ancestors_node_id_to_debug_handle:
11021116
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
1103-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1104-
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1105-
parent_node_id
1106-
]
1107-
else:
1108-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
1117+
elif parent_node_ids:
1118+
for parent_node_id in parent_node_ids:
1119+
if parent_node_id in ancestors_node_id_to_debug_handle:
1120+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1121+
parent_node_id
1122+
]
1123+
break
11091124

11101125
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
11111126

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,15 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
334334
self.assertEqual(actual, expected)
335335

336336
def test_map_runtime_aot_intermediate_outputs_partial_match(self):
337-
# Partial match between aot and runtime debug_handles will return empty
337+
# Partial match between aot and runtime debug_handles will return
338+
# matching debug handles from runtime
338339
aot_intermediate_outputs = {(2,): 100, (9,): 300}
339340
runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)}
340341
actual = map_runtime_aot_intermediate_outputs(
341342
aot_intermediate_outputs, runtime_intermediate_outputs
342343
)
343-
expected = {}
344+
# Runtime entry (8, 9) matches because its last debug handle, 9, is an AOT debug handle
345+
expected = {((8, 9), 300): ((8, 9), 300)}
344346
self.assertEqual(actual, expected)
345347

346348
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):

0 commit comments

Comments
 (0)