Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions devtools/inspector/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ python_library(
"_intermediate_output_capturer.py",
],
deps = [
"//executorch/devtools/inspector:inspector_utils",
],
)

Expand Down
5 changes: 5 additions & 0 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
FORWARD,
gen_etdump_object,
gen_graphs_from_etrecord,
get_aot_debug_handle_to_op_name_mapping,
inflate_runtime_output,
is_debug_output,
is_inference_output_equal,
Expand Down Expand Up @@ -1084,6 +1085,7 @@ def __init__(
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
self._enable_module_hierarchy = enable_module_hierarchy
self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None
self._consume_etrecord()

def _consume_etrecord(self) -> None:
Expand Down Expand Up @@ -1150,6 +1152,9 @@ def _consume_etrecord(self) -> None:
return
export_program = self._etrecord.edge_dialect_program
graph_module = export_program.module()
self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping(
graph_module
)
capturer = IntermediateOutputCapturer(graph_module)
self._aot_intermediate_outputs = capturer.run_and_capture(
self._etrecord._representative_inputs
Expand Down
50 changes: 50 additions & 0 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,28 @@ class NodeData:
output: Any


class NodeFilter:
    """
    A class used to filter nodes based on extensible criteria.

    A node matches when all of the following hold:
      * its metadata contains ``metadata_key`` with a non-None value,
      * its ``op`` equals ``op_type``, and
      * its name contains none of the substrings in ``exclude_ops``.

    Attributes:
        metadata_key (str): The key to look for in the node's metadata.
        op_type (str): The operation code to match.
        exclude_ops (List[str]): A list of operations to exclude from the filter.
    """

    def __init__(
        self,
        metadata_key: str,
        op_type: str,
        exclude_ops: Optional[List[str]] = None,
    ):
        self.metadata_key = metadata_key
        self.op_type = op_type
        # Normalize None to an empty list so `matches` never iterates None
        # (the previous default would raise TypeError when exclude_ops was
        # omitted by the caller).
        self.exclude_ops = exclude_ops if exclude_ops is not None else []

    def matches(self, node: torch.fx.Node) -> bool:
        """Return True if `node` satisfies every filter criterion."""
        return (
            node.meta.get(self.metadata_key) is not None
            and node.op == self.op_type
            # Exclusion is substring-based on the node's name.
            and all(exclude_name not in node.name for exclude_name in self.exclude_ops)
        )


def calculate_time_scale_factor(
source_time_scale: TimeScale, target_time_scale: TimeScale
) -> float:
Expand Down Expand Up @@ -734,3 +756,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
if torch.isnan(input_tensor).any():
input_tensor = torch.nan_to_num(input_tensor)
return input_tensor


def get_aot_debug_handle_to_op_name_mapping(
    graph_module: torch.fx.GraphModule,
) -> Dict[Tuple[int, ...], str]:
    """
    Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.

    Parameters:
        graph_module (torch.fx.GraphModule): The graph module to get the mapping from.

    Returns:
        Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names.
    """
    # Only call_function nodes that carry a debug handle are considered;
    # "getitem" nodes are skipped because they unpack multi-output ops
    # rather than perform computation of their own.
    node_filters = [
        NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
    ]

    debug_handle_to_op_name: Dict[Tuple[int, ...], str] = {}
    for node in graph_module.graph.nodes:
        # `node_filter` rather than `filter` — avoid shadowing the builtin.
        if all(node_filter.matches(node) for node_filter in node_filters):
            debug_handle = node.meta["debug_handle"]
            # Normalize to a tuple so both int and sequence handles can be
            # used as dictionary keys.
            key = (
                (debug_handle,)
                if isinstance(debug_handle, int)
                else tuple(debug_handle)
            )
            debug_handle_to_op_name[key] = node.name
    return debug_handle_to_op_name
25 changes: 2 additions & 23 deletions devtools/inspector/_intermediate_output_capturer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,14 @@
# pyre-unsafe


from typing import Any, Dict, List, Tuple
from typing import Any, Dict, Tuple

import torch
from executorch.devtools.inspector._inspector_utils import NodeFilter
from torch.fx import GraphModule
from torch.fx.interpreter import Interpreter


class NodeFilter:
    """
    A class used to filter nodes based on extensible criteria.

    A node matches when its metadata contains `metadata_key` with a non-None
    value, its op equals `op_type`, and its name contains none of the
    substrings in `exclude_ops`.

    Attributes:
        metadata_key (str): The key to look for in the node's metadata.
        op_type (str): The operation code to match.
        exclude_ops (List[str]): A list of operations to exclude from the filter.
    """

    def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None):
        self.metadata_key = metadata_key
        self.op_type = op_type
        # NOTE(review): a None default makes `matches` raise TypeError
        # (it iterates exclude_ops); callers must always pass exclude_ops.
        self.exclude_ops = exclude_ops

    def matches(self, node: torch.fx.Node) -> bool:
        # Exclusion is substring-based on the node's name.
        return (
            node.meta.get(self.metadata_key) is not None
            and node.op == self.op_type
            and all(exclude_name not in node.name for exclude_name in self.exclude_ops)
        )


class IntermediateOutputCapturer(Interpreter):
"""
A class that captures intermediate outputs from a PyTorch graph module.
Expand Down
108 changes: 108 additions & 0 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
EDGE_DIALECT_GRAPH_KEY,
find_populated_event,
gen_graphs_from_etrecord,
get_aot_debug_handle_to_op_name_mapping,
is_inference_output_equal,
map_runtime_aot_intermediate_outputs,
merge_overlapping_debug_handles,
NodeFilter,
TimeScale,
)

Expand Down Expand Up @@ -364,6 +366,112 @@ class X:
msg = str(cm.exception)
self.assertIn("Cannot convert value of type", msg)

def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self):
    """An int debug handle is normalized to a one-element tuple key."""
    gm = torch.fx.GraphModule({}, torch.fx.Graph())
    mul_node = gm.graph.create_node(
        "call_function", target=torch.mul, args=(), kwargs={}, name="op1"
    )
    mul_node.meta["debug_handle"] = 1
    self.assertEqual(
        get_aot_debug_handle_to_op_name_mapping(gm), {(1,): "op1"}
    )

def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
    """Tuple handles are kept as-is; int handles become one-element tuples."""
    gm = torch.fx.GraphModule({}, torch.fx.Graph())
    first = gm.graph.create_node(
        "call_function", target=torch.mul, args=(), kwargs={}, name="op1"
    )
    first.meta["debug_handle"] = (1, 2)
    second = gm.graph.create_node(
        "call_function", target=torch.mul, args=(), kwargs={}, name="op2"
    )
    second.meta["debug_handle"] = 3
    expected = {(1, 2): "op1", (3,): "op2"}
    self.assertEqual(get_aot_debug_handle_to_op_name_mapping(gm), expected)

def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self):
    """A graph with no nodes yields an empty mapping."""
    empty_gm = torch.fx.GraphModule({}, torch.fx.Graph())
    self.assertEqual(get_aot_debug_handle_to_op_name_mapping(empty_gm), {})

def test_node_filter_match(self):
    """A call_function node carrying a debug handle, with a non-excluded name, matches."""
    handle_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    candidate = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="mock_node",
        op="call_function",
        target=torch.nn.functional.relu,
        args=(),
        kwargs={},
    )
    candidate.meta["debug_handle"] = (1, 2)
    self.assertTrue(handle_filter.matches(candidate))

def test_node_filter_key_mismatch(self):
    """A node whose meta lacks the debug_handle key does not match."""
    handle_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    keyless_node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="mock_node_metadata_key_mismatch",
        op="call_function",
        target=torch.nn.functional.relu,
        args=(),
        kwargs={},
    )
    self.assertFalse(handle_filter.matches(keyless_node))

def test_node_filter_ops_mismatch(self):
    """A node whose name contains an excluded substring does not match."""
    handle_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    excluded_node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="getitem",
        op="call_function",
        target=torch.nn.functional.relu,
        args=(),
        kwargs={},
    )
    excluded_node.meta["debug_handle"] = (1, 2)
    self.assertFalse(handle_filter.matches(excluded_node))

def test_node_op_type_mismatch(self):
    """A node whose op is not the filtered op_type does not match."""
    handle_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    attr_node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="mock_node_op_type_mismatch",
        op="get_attr",
        target="torch.nn.functional.relu",
        args=(),
        kwargs={},
    )
    attr_node.meta["debug_handle"] = (1, 2)
    self.assertFalse(handle_filter.matches(attr_node))


def gen_mock_operator_graph_with_expected_map() -> (
Tuple[OperatorGraph, Dict[int, OperatorNode]]
Expand Down
Loading