Skip to content

Commit 0a0f986

Browse files
Juntian Liu and facebook-github-bot
authored and committed
Add function to map AOT debug_handles to op names (#11930)
Summary: This PR adds a function to map AOT debug handles to operator names in the Export graph. It will be used later to enhance how numerical discrepancy results are shown, making it easier for users to understand. Reviewed By: Gasoonjia Differential Revision: D77244175
1 parent 42195da commit 0a0f986

File tree

5 files changed

+166
-23
lines changed

5 files changed

+166
-23
lines changed

devtools/inspector/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ python_library(
5656
"_intermediate_output_capturer.py",
5757
],
5858
deps = [
59+
"//executorch/devtools/inspector:inspector_utils",
5960
],
6061
)
6162

devtools/inspector/_inspector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
FORWARD,
5353
gen_etdump_object,
5454
gen_graphs_from_etrecord,
55+
get_aot_debug_handle_to_op_name_mapping,
5556
inflate_runtime_output,
5657
is_debug_output,
5758
is_inference_output_equal,
@@ -1084,6 +1085,7 @@ def __init__(
10841085
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
10851086
self._enable_module_hierarchy = enable_module_hierarchy
10861087
self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
1088+
self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None
10871089
self._consume_etrecord()
10881090

10891091
def _consume_etrecord(self) -> None:
@@ -1150,6 +1152,9 @@ def _consume_etrecord(self) -> None:
11501152
return
11511153
export_program = self._etrecord.edge_dialect_program
11521154
graph_module = export_program.module()
1155+
self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping(
1156+
graph_module
1157+
)
11531158
capturer = IntermediateOutputCapturer(graph_module)
11541159
self._aot_intermediate_outputs = capturer.run_and_capture(
11551160
self._etrecord._representative_inputs

devtools/inspector/_inspector_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,28 @@ class NodeData:
9393
output: Any
9494

9595

96+
class NodeFilter:
97+
"""
98+
A class used to filter nodes based on extensible criteria.
99+
Attributes:
100+
metadata_key (str): The key to look for in the node's metadata.
101+
op_type (str): The operation code to match.
102+
exclude_ops (List[str]): A list of operations to exclude from the filter.
103+
"""
104+
105+
def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None):
106+
self.metadata_key = metadata_key
107+
self.op_type = op_type
108+
self.exclude_ops = exclude_ops
109+
110+
def matches(self, node: torch.fx.Node) -> bool:
111+
return (
112+
node.meta.get(self.metadata_key) is not None
113+
and node.op == self.op_type
114+
and all(exclude_name not in node.name for exclude_name in self.exclude_ops)
115+
)
116+
117+
96118
def calculate_time_scale_factor(
97119
source_time_scale: TimeScale, target_time_scale: TimeScale
98120
) -> float:
@@ -734,3 +756,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
734756
if torch.isnan(input_tensor).any():
735757
input_tensor = torch.nan_to_num(input_tensor)
736758
return input_tensor
759+
760+
761+
def get_aot_debug_handle_to_op_name_mapping(
762+
graph_module: torch.fx.GraphModule,
763+
) -> Dict[Tuple[int, ...], str]:
764+
"""
765+
Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
766+
Parameters:
767+
graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
768+
Returns:
769+
Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names.
770+
"""
771+
node_filters = [
772+
NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
773+
]
774+
775+
debug_handle_to_op_name = {}
776+
for node in graph_module.graph.nodes:
777+
if all(filter.matches(node) for filter in node_filters):
778+
debug_handle = node.meta["debug_handle"]
779+
# Convert the debug handle to a tuple to use as a dictionary key
780+
key = (
781+
(debug_handle,)
782+
if isinstance(debug_handle, int)
783+
else tuple(debug_handle)
784+
)
785+
debug_handle_to_op_name[key] = node.name
786+
return debug_handle_to_op_name

devtools/inspector/_intermediate_output_capturer.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,14 @@
77
# pyre-unsafe
88

99

10-
from typing import Any, Dict, List, Tuple
10+
from typing import Any, Dict, Tuple
1111

1212
import torch
13+
from executorch.devtools.inspector._inspector_utils import NodeFilter
1314
from torch.fx import GraphModule
1415
from torch.fx.interpreter import Interpreter
1516

1617

17-
class NodeFilter:
18-
"""
19-
A class used to filter nodes based on extensible criteria.
20-
Attributes:
21-
metadata_key (str): The key to look for in the node's metadata.
22-
op_type (str): The operation code to match.
23-
exclude_ops (List[str]): A list of operations to exclude from the filter.
24-
"""
25-
26-
def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None):
27-
self.metadata_key = metadata_key
28-
self.op_type = op_type
29-
self.exclude_ops = exclude_ops
30-
31-
def matches(self, node: torch.fx.Node) -> bool:
32-
return (
33-
node.meta.get(self.metadata_key) is not None
34-
and node.op == self.op_type
35-
and all(exclude_name not in node.name for exclude_name in self.exclude_ops)
36-
)
37-
38-
3918
class IntermediateOutputCapturer(Interpreter):
4019
"""
4120
A class that captures intermediate outputs from a PyTorch graph module.

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@
3434
EDGE_DIALECT_GRAPH_KEY,
3535
find_populated_event,
3636
gen_graphs_from_etrecord,
37+
get_aot_debug_handle_to_op_name_mapping,
3738
is_inference_output_equal,
3839
map_runtime_aot_intermediate_outputs,
3940
merge_overlapping_debug_handles,
41+
NodeFilter,
4042
TimeScale,
4143
)
4244

@@ -364,6 +366,112 @@ class X:
364366
msg = str(cm.exception)
365367
self.assertIn("Cannot convert value of type", msg)
366368

369+
def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self):
370+
# Create a simple graph module with one node
371+
graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
372+
node = graph_module.graph.create_node(
373+
"call_function", target=torch.mul, args=(), kwargs={}, name="op1"
374+
)
375+
node.meta["debug_handle"] = 1
376+
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
377+
expected_result = {(1,): "op1"}
378+
self.assertEqual(debug_handle_to_op_name, expected_result)
379+
380+
def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
381+
# Create a simple graph module with two nodes
382+
graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
383+
node1 = graph_module.graph.create_node(
384+
"call_function", target=torch.mul, args=(), kwargs={}, name="op1"
385+
)
386+
node1.meta["debug_handle"] = (1, 2)
387+
node2 = graph_module.graph.create_node(
388+
"call_function", target=torch.mul, args=(), kwargs={}, name="op2"
389+
)
390+
node2.meta["debug_handle"] = 3
391+
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
392+
expected_result = {
393+
(
394+
1,
395+
2,
396+
): "op1",
397+
(3,): "op2",
398+
}
399+
self.assertEqual(debug_handle_to_op_name, expected_result)
400+
401+
def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self):
402+
# Create a simple graph module with no nodes
403+
graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
404+
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
405+
expected_result = {}
406+
self.assertEqual(debug_handle_to_op_name, expected_result)
407+
408+
def test_node_filter_match(self):
409+
node_filter = NodeFilter(
410+
"debug_handle", "call_function", exclude_ops=["getitem"]
411+
)
412+
413+
# Create a mock node that matches the filter criteria
414+
mock_node = torch.fx.Node(
415+
graph=torch.fx.Graph(),
416+
name="mock_node",
417+
op="call_function",
418+
target=torch.nn.functional.relu,
419+
args=(),
420+
kwargs={},
421+
)
422+
mock_node.meta["debug_handle"] = (1, 2)
423+
# Test that the filter matches the mock node
424+
self.assertTrue(node_filter.matches(mock_node))
425+
426+
def test_node_filter_key_mismatch(self):
427+
node_filter = NodeFilter(
428+
"debug_handle", "call_function", exclude_ops=["getitem"]
429+
)
430+
mock_node_metadata_key_mismatch = torch.fx.Node(
431+
graph=torch.fx.Graph(),
432+
name="mock_node_metadata_key_mismatch",
433+
op="call_function",
434+
target=torch.nn.functional.relu,
435+
args=(),
436+
kwargs={},
437+
)
438+
# Test that the filter doesn't match the mock node (meta doesn't have debug_handle key)
439+
self.assertFalse(node_filter.matches(mock_node_metadata_key_mismatch))
440+
441+
def test_node_filter_ops_mismatch(self):
442+
node_filter = NodeFilter(
443+
"debug_handle", "call_function", exclude_ops=["getitem"]
444+
)
445+
446+
mock_node_exclude_ops_mismatch = torch.fx.Node(
447+
graph=torch.fx.Graph(),
448+
name="getitem",
449+
op="call_function",
450+
target=torch.nn.functional.relu,
451+
args=(),
452+
kwargs={},
453+
)
454+
mock_node_exclude_ops_mismatch.meta["debug_handle"] = (1, 2)
455+
# Test that the filter doesn't match the mock node (exclude_ops mismatch)
456+
self.assertFalse(node_filter.matches(mock_node_exclude_ops_mismatch))
457+
458+
def test_node_op_type_mismatch(self):
459+
node_filter = NodeFilter(
460+
"debug_handle", "call_function", exclude_ops=["getitem"]
461+
)
462+
463+
mock_node_op_type_mismatch = torch.fx.Node(
464+
graph=torch.fx.Graph(),
465+
name="mock_node_op_type_mismatch",
466+
op="get_attr",
467+
target="torch.nn.functional.relu",
468+
args=(),
469+
kwargs={},
470+
)
471+
mock_node_op_type_mismatch.meta["debug_handle"] = (1, 2)
472+
# Test that the filter doesn't match the mock node (op_type mismatch)
473+
self.assertFalse(node_filter.matches(mock_node_op_type_mismatch))
474+
367475

368476
def gen_mock_operator_graph_with_expected_map() -> (
369477
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments (0)