Skip to content

Commit 6b7115e

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Allow for matching debug handles with partial overlap between aten graph and runtime
Summary: When aten graph is modified for debug, for instance using int4 matmul, it wont have complete overlap with debug handles recorded by the delegate. For example, original model will have chose_qparams,q, dq, dq, linear nodes. Delegate will record debug hanlde for all of those. Say those are (4, 5, 6, 7, 8). When int4 matmul rewrite pass, from torchao, is applied, we just inherit from_node information from linear node. Thus only the last debug handle 8 is associated with custom op int4 node. Thus when we map delegate debug handles with custom op we find overlap for 8 only. This diff allows to look for overlapping match instead of exact match. Plus it also changes the code for AOT debug handle so that we can look for all ancestor nodes instead of just parent node. This is also needed so as to allow for numerical comparison despite passes applied on original aten graph. Reviewed By: Gasoonjia Differential Revision: D82229367
1 parent 0f066e0 commit 6b7115e

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs(
657657
# Combine all AOT debug_handles into a list
658658
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
659659

660-
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
661-
# AOT combined debug_handle and runtime debug_handle do not match.
660+
# Reason we dont check for exact match:
661+
# in some experiments where we want to rewrite the aten graph that was
662+
# lowered, so as to use custom ops like int4_matmul, we lose some nodes
663+
# on the graph and thus lose some debug handles. And we dont find
664+
# exact match within connected components.
665+
if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)):
666+
# AOT combined debug_handle is not a subset of runtime debug_handle.
662667
return (-1,), None
663668

664669
# Pick the last intermediate output
665670
last_int = runtime_debug_handle[negative_index]
666671
key = (last_int,)
672+
if key not in aot_map:
673+
# If the last intermediate output is not in the AOT map, return None
674+
return (-1,), None
667675
return runtime_debug_handle, aot_map[key]
668676

669677

@@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None:
10591067
if node.op in ("output", "placeholder"):
10601068
return
10611069
node_id = f"{node.name}.{exported_program_graph_id}"
1062-
parent_node_id = get_parent_node_identifier(node)
1070+
parent_node_ids = get_ancestor_node_identifiers(node)
10631071
if node_id in ancestors_node_id_to_debug_handle:
10641072
matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id])
1065-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1066-
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])
1073+
elif parent_node_ids:
1074+
for parent_node_id in parent_node_ids:
1075+
if parent_node_id in ancestors_node_id_to_debug_handle:
1076+
matched_debug_handles.add(
1077+
ancestors_node_id_to_debug_handle[parent_node_id]
1078+
)
1079+
break
10671080

10681081
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
10691082
return matched_debug_handles
@@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None:
10971110
if node.op in ("output", "placeholder"):
10981111
return
10991112
node_id = f"{node.name}.{exported_program_graph_id}"
1100-
parent_node_id = get_parent_node_identifier(node)
1113+
parent_node_ids = get_ancestor_node_identifiers(node)
1114+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
11011115
if node_id in ancestors_node_id_to_debug_handle:
11021116
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
1103-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1104-
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1105-
parent_node_id
1106-
]
1107-
else:
1108-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
1117+
elif parent_node_ids:
1118+
for parent_node_id in parent_node_ids:
1119+
if parent_node_id in ancestors_node_id_to_debug_handle:
1120+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1121+
parent_node_id
1122+
]
1123+
break
11091124

11101125
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
11111126

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ def test_map_runtime_aot_intermediate_outputs_partial_match(self):
340340
actual = map_runtime_aot_intermediate_outputs(
341341
aot_intermediate_outputs, runtime_intermediate_outputs
342342
)
343-
expected = {}
343+
# Since the runtime output debug handle of 9 is there in aot debug handle
344+
expected = {((8, 9), 300): ((8, 9), 300)}
344345
self.assertEqual(actual, expected)
345346

346347
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):

0 commit comments

Comments
 (0)