Skip to content

Commit ea5d350

Browse files
Juntian Liu and facebook-github-bot
authored and committed
Mapping AOT debug_handles to op names
Summary: This PR adds a function to map AOT debug handles to operator names in the Export graph. It will be used later to enhance how numerical discrepancy results are shown, making it easier for users to understand. Differential Revision: D77244175
1 parent d5fe5fa commit ea5d350

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

devtools/inspector/_inspector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
FORWARD,
5353
gen_etdump_object,
5454
gen_graphs_from_etrecord,
55+
get_aot_debug_handle_to_op_name_mapping,
5556
inflate_runtime_output,
5657
is_debug_output,
5758
is_inference_output_equal,
@@ -1084,6 +1085,7 @@ def __init__(
10841085
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
10851086
self._enable_module_hierarchy = enable_module_hierarchy
10861087
self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
1088+
self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None
10871089
self._consume_etrecord()
10881090

10891091
def _consume_etrecord(self) -> None:
@@ -1150,6 +1152,9 @@ def _consume_etrecord(self) -> None:
11501152
return
11511153
export_program = self._etrecord.edge_dialect_program
11521154
graph_module = export_program.module()
1155+
self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping(
1156+
graph_module
1157+
)
11531158
capturer = IntermediateOutputCapturer(graph_module)
11541159
self._aot_intermediate_outputs = capturer.run_and_capture(
11551160
self._etrecord._representative_inputs

devtools/inspector/_inspector_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,3 +734,27 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
734734
if torch.isnan(input_tensor).any():
735735
input_tensor = torch.nan_to_num(input_tensor)
736736
return input_tensor
737+
738+
739+
def get_aot_debug_handle_to_op_name_mapping(
740+
graph_module: torch.fx.GraphModule,
741+
) -> Dict[Tuple[int, ...], str]:
742+
"""
743+
Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
744+
Parameters:
745+
graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
746+
Returns:
747+
Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names.
748+
"""
749+
debug_handle_to_op_name = {}
750+
for node in graph_module.graph.nodes:
751+
if "debug_handle" in node.meta:
752+
debug_handle = node.meta["debug_handle"]
753+
# Convert the debug handle to a tuple to use as a dictionary key
754+
key = (
755+
(debug_handle,)
756+
if isinstance(debug_handle, int)
757+
else tuple(debug_handle)
758+
)
759+
debug_handle_to_op_name[key] = node.name
760+
return debug_handle_to_op_name

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
EDGE_DIALECT_GRAPH_KEY,
3535
find_populated_event,
3636
gen_graphs_from_etrecord,
37+
get_aot_debug_handle_to_op_name_mapping,
3738
is_inference_output_equal,
3839
map_runtime_aot_intermediate_outputs,
3940
merge_overlapping_debug_handles,
@@ -364,6 +365,45 @@ class X:
364365
msg = str(cm.exception)
365366
self.assertIn("Cannot convert value of type", msg)
366367

368+
def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self):
369+
# Create a simple graph module with one node
370+
graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
371+
node = graph_module.graph.create_node(
372+
"call_function", target=torch.mul, args=(), kwargs={}, name="op1"
373+
)
374+
node.meta["debug_handle"] = 1
375+
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
376+
expected_result = {(1,): "op1"}
377+
self.assertEqual(debug_handle_to_op_name, expected_result)
378+
379+
def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
380+
# Create a simple graph module with two nodes
381+
graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
382+
node1 = graph_module.graph.create_node(
383+
"call_function", target=torch.mul, args=(), kwargs={}, name="op1"
384+
)
385+
node1.meta["debug_handle"] = (1, 2)
386+
node2 = graph_module.graph.create_node(
387+
"call_function", target=torch.mul, args=(), kwargs={}, name="op2"
388+
)
389+
node2.meta["debug_handle"] = 3
390+
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
391+
expected_result = {
392+
(
393+
1,
394+
2,
395+
): "op1",
396+
(3,): "op2",
397+
}
398+
self.assertEqual(debug_handle_to_op_name, expected_result)
399+
400+
def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self):
401+
# Create a simple graph module with no nodes
402+
graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
403+
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
404+
expected_result = {}
405+
self.assertEqual(debug_handle_to_op_name, expected_result)
406+
367407

368408
def gen_mock_operator_graph_with_expected_map() -> (
369409
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)