diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index e4ddcce1ce7..c797208c0c9 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -43,6 +43,7 @@ from executorch.devtools.inspector._inspector_utils import ( calculate_time_scale_factor, create_debug_handle_to_op_node_mapping, + DebugHandle, display_or_print_df, EDGE_DIALECT_GRAPH_KEY, EXCLUDED_COLUMNS_WHEN_PRINTING, @@ -262,7 +263,7 @@ class RunSignature: # Typing for mapping Event.delegate_debug_identifiers to debug_handle(s) DelegateIdentifierDebugHandleMap: TypeAlias = Union[ - Mapping[int, Tuple[int, ...]], Mapping[str, Tuple[int, ...]] + Mapping[int, DebugHandle], Mapping[str, DebugHandle] ] # Typing for Dict containig delegate metadata @@ -1149,7 +1150,7 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, - ) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]: + ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]: """ Capture intermediate outputs only if _representative_inputs are provided when using bundled program to create the etrecord @@ -1170,7 +1171,7 @@ def _get_aot_intermediate_outputs_and_op_names( # TODO: Make it more extensible to further merge overlapping debug handles def _get_runtime_intermediate_outputs_and_op_names( self, - ) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]: + ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]: """ Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings) from the event blocks, along with the corresponding debug handles and op names mapping. diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 6869c793946..249a2203e4c 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -73,6 +73,8 @@ class TimeScale(Enum): TimeScale.CYCLES: 1, } +DebugHandle: TypeAlias = Tuple[int, ...] + class NodeSource(Enum): AOT = 1 @@ -528,7 +530,7 @@ def compare_results( return results -def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], Any]): +def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]): """ Merge overlapping debug handles int a single key """ @@ -558,7 +560,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], def _debug_handles_have_overlap( - aot_debug_hanlde: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...] + aot_debug_hanlde: DebugHandle, runtime_debug_handle: DebugHandle ) -> bool: """ Check if the AOT debug handle and the runtime debug handle have any overlap. @@ -568,7 +570,7 @@ def _debug_handles_have_overlap( return len(aot_set.intersection(runtime_set)) > 0 -def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]: +def _combine_debug_hanldes(debug_handles: List[DebugHandle]) -> DebugHandle: """Combine multiple debug handles into one debug handle""" combined_debug_handles_set = set() for debug_handle in debug_handles: @@ -577,8 +579,8 @@ def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, . def _combine_overlapped_intermediate_outputs( - nodes: List[Tuple[Tuple[int, ...], Any]] -) -> Tuple[Tuple[int, ...], Any]: + nodes: List[Tuple[DebugHandle, Any]] +) -> Tuple[DebugHandle, Any]: """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output""" debug_handles = [debug_handle for debug_handle, _ in nodes] outputs = [output for _, output in nodes] @@ -588,8 +590,8 @@ def _combine_overlapped_intermediate_outputs( def _create_debug_handle_overlap_graph( - aot_intermediate_outputs: Dict[Tuple[int, ...], Any], - runtime_intermediate_outputs: Dict[Tuple[int, ...], Any], + aot_intermediate_outputs: Dict[DebugHandle, Any], + runtime_intermediate_outputs: Dict[DebugHandle, Any], ) -> Tuple[List[NodeData], Dict[int, List[int]]]: """ Create a graph representing overlapping debug handles between AOT and runtime outputs. @@ -659,15 +661,15 @@ def dfs(node_id, component): def map_runtime_aot_intermediate_outputs( - aot_intermediate_outputs: Dict[Tuple[int, ...], Any], - runtime_intermediate_outputs: Dict[Tuple[int, ...], Any], -) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]: + aot_intermediate_outputs: Dict[DebugHandle, Any], + runtime_intermediate_outputs: Dict[DebugHandle, Any], +) -> Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]: """ Map the runtime intermediate outputs to the AOT intermediate outputs by finding overlapping debug handles and combining them into a single debug_handle Returns: - Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping + Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] - Mapping from runtime intermediate output to AOT intermediate output """ # Merge overlapping debug handles @@ -760,13 +762,13 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: def get_aot_debug_handle_to_op_name_mapping( graph_module: torch.fx.GraphModule, -) -> Dict[Tuple[int, ...], str]: +) -> Dict[DebugHandle, str]: """ Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module. Parameters: graph_module (torch.fx.GraphModule): The graph module to get the mapping from. Returns: - Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names. + Dict[DebugHandle, str]: A dictionary mapping debug handles to operator names. """ node_filters = [ NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"]) @@ -787,8 +789,8 @@ def get_aot_debug_handle_to_op_name_mapping( def find_op_names( - target_debug_handle: Tuple[int, ...], - debug_handle_to_op_name: Dict[Tuple[int, ...], str], + target_debug_handle: DebugHandle, + debug_handle_to_op_name: Dict[DebugHandle, str], ) -> List[str]: """ Record the operator names only if their debug handles are part of the target debug handle. diff --git a/devtools/inspector/_intermediate_output_capturer.py b/devtools/inspector/_intermediate_output_capturer.py index 054c97dc245..75737060fbf 100644 --- a/devtools/inspector/_intermediate_output_capturer.py +++ b/devtools/inspector/_intermediate_output_capturer.py @@ -7,10 +7,10 @@ # pyre-unsafe -from typing import Any, Dict, Tuple +from typing import Any, Dict import torch -from executorch.devtools.inspector._inspector_utils import NodeFilter +from executorch.devtools.inspector._inspector_utils import DebugHandle, NodeFilter from torch.fx import GraphModule from torch.fx.interpreter import Interpreter @@ -30,7 +30,7 @@ def __init__(self, module: GraphModule): ] # Runs the graph module and captures the intermediate outputs. - def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]: + def run_and_capture(self, *args, **kwargs) -> Dict[DebugHandle, Any]: captured_outputs = {} def capture_run_node(n: torch.fx.Node) -> Any: