Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from executorch.devtools.inspector._inspector_utils import (
calculate_time_scale_factor,
create_debug_handle_to_op_node_mapping,
DebugHandle,
display_or_print_df,
EDGE_DIALECT_GRAPH_KEY,
EXCLUDED_COLUMNS_WHEN_PRINTING,
Expand Down Expand Up @@ -262,7 +263,7 @@ class RunSignature:

# Typing for mapping Event.delegate_debug_identifiers to debug_handle(s)
DelegateIdentifierDebugHandleMap: TypeAlias = Union[
Mapping[int, Tuple[int, ...]], Mapping[str, Tuple[int, ...]]
Mapping[int, DebugHandle], Mapping[str, DebugHandle]
]

# Typing for Dict containig delegate metadata
Expand Down Expand Up @@ -1149,7 +1150,7 @@ def _consume_etrecord(self) -> None:

def _get_aot_intermediate_outputs_and_op_names(
self,
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
"""
Capture intermediate outputs only if _representative_inputs are provided
when using bundled program to create the etrecord
Expand All @@ -1170,7 +1171,7 @@ def _get_aot_intermediate_outputs_and_op_names(
# TODO: Make it more extensible to further merge overlapping debug handles
def _get_runtime_intermediate_outputs_and_op_names(
self,
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
"""
Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
from the event blocks, along with the corresponding debug handles and op names mapping.
Expand Down
32 changes: 17 additions & 15 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class TimeScale(Enum):
TimeScale.CYCLES: 1,
}

DebugHandle: TypeAlias = Tuple[int, ...]


class NodeSource(Enum):
AOT = 1
Expand Down Expand Up @@ -528,7 +530,7 @@ def compare_results(
return results


def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], Any]):
def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]):
"""
Merge overlapping debug handles int a single key
"""
Expand Down Expand Up @@ -558,7 +560,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],


def _debug_handles_have_overlap(
aot_debug_hanlde: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...]
aot_debug_hanlde: DebugHandle, runtime_debug_handle: DebugHandle
) -> bool:
"""
Check if the AOT debug handle and the runtime debug handle have any overlap.
Expand All @@ -568,7 +570,7 @@ def _debug_handles_have_overlap(
return len(aot_set.intersection(runtime_set)) > 0


def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]:
def _combine_debug_hanldes(debug_handles: List[DebugHandle]) -> DebugHandle:
"""Combine multiple debug handles into one debug handle"""
combined_debug_handles_set = set()
for debug_handle in debug_handles:
Expand All @@ -577,8 +579,8 @@ def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, .


def _combine_overlapped_intermediate_outputs(
nodes: List[Tuple[Tuple[int, ...], Any]]
) -> Tuple[Tuple[int, ...], Any]:
nodes: List[Tuple[DebugHandle, Any]]
) -> Tuple[DebugHandle, Any]:
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
debug_handles = [debug_handle for debug_handle, _ in nodes]
outputs = [output for _, output in nodes]
Expand All @@ -588,8 +590,8 @@ def _combine_overlapped_intermediate_outputs(


def _create_debug_handle_overlap_graph(
aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
aot_intermediate_outputs: Dict[DebugHandle, Any],
runtime_intermediate_outputs: Dict[DebugHandle, Any],
) -> Tuple[List[NodeData], Dict[int, List[int]]]:
"""
Create a graph representing overlapping debug handles between AOT and runtime outputs.
Expand Down Expand Up @@ -659,15 +661,15 @@ def dfs(node_id, component):


def map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
aot_intermediate_outputs: Dict[DebugHandle, Any],
runtime_intermediate_outputs: Dict[DebugHandle, Any],
) -> Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]:
"""
Map the runtime intermediate outputs to the AOT intermediate outputs
by finding overlapping debug handles and combining them into a single debug_handle

Returns:
Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] - Mapping
from runtime intermediate output to AOT intermediate output
"""
# Merge overlapping debug handles
Expand Down Expand Up @@ -760,13 +762,13 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:

def get_aot_debug_handle_to_op_name_mapping(
graph_module: torch.fx.GraphModule,
) -> Dict[Tuple[int, ...], str]:
) -> Dict[DebugHandle, str]:
"""
Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
Parameters:
graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
Returns:
Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names.
Dict[DebugHandle, str]: A dictionary mapping debug handles to operator names.
"""
node_filters = [
NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
Expand All @@ -787,8 +789,8 @@ def get_aot_debug_handle_to_op_name_mapping(


def find_op_names(
target_debug_handle: Tuple[int, ...],
debug_handle_to_op_name: Dict[Tuple[int, ...], str],
target_debug_handle: DebugHandle,
debug_handle_to_op_name: Dict[DebugHandle, str],
) -> List[str]:
"""
Record the operator names only if their debug handles are part of the target debug handle.
Expand Down
6 changes: 3 additions & 3 deletions devtools/inspector/_intermediate_output_capturer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
# pyre-unsafe


from typing import Any, Dict, Tuple
from typing import Any, Dict

import torch
from executorch.devtools.inspector._inspector_utils import NodeFilter
from executorch.devtools.inspector._inspector_utils import DebugHandle, NodeFilter
from torch.fx import GraphModule
from torch.fx.interpreter import Interpreter

Expand All @@ -30,7 +30,7 @@ def __init__(self, module: GraphModule):
]

# Runs the graph module and captures the intermediate outputs.
def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
def run_and_capture(self, *args, **kwargs) -> Dict[DebugHandle, Any]:
captured_outputs = {}

def capture_run_node(n: torch.fx.Node) -> Any:
Expand Down
Loading