diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index de7cf93990a..c6decb8adf1 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -232,9 +232,10 @@ def generate_etrecord( edge_dialect_program.exported_program, ) else: - raise RuntimeError( - f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}." - ) + if export_modules is None: + raise RuntimeError( + f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}." + ) # When a BundledProgram is passed in, extract the reference outputs and save in a file if isinstance(executorch_program, BundledProgram): diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 989663601f8..6eba89d45c8 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -978,6 +978,8 @@ def __init__( Callable[[Union[int, str], Union[int, float]], Union[int, float]] ] = None, enable_module_hierarchy: bool = False, + module_name: Optional[str] = None, + method_name: Optional[str] = None, ) -> None: r""" Initialize an `Inspector` instance with the underlying `EventBlock`\ s populated with data from the provided ETDump path or binary, @@ -995,6 +997,8 @@ def __init__( delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of target_time_scale/source_time_scale. enable_module_hierarchy: Enable submodules in the operator graph. Defaults to False. + module_name: Optional module name to inspect (used with multi-module exports). + method_name: Optional method name to inspect (used with multi-module exports). Returns: None @@ -1059,9 +1063,13 @@ def __init__( # Key str is method name; value is list of ProgramOutputs because of list of test cases self._reference_outputs: Dict[str, List[ProgramOutput]] = {} self._enable_module_hierarchy = enable_module_hierarchy - self._consume_etrecord() + self._consume_etrecord(module_name, method_name) - def _consume_etrecord(self) -> None: + def _consume_etrecord( + self, + module_name: Optional[str] = None, + method_name: Optional[str] = None, + ) -> None: """ If an ETRecord is provided, connect it to the EventBlocks and populate the Event metadata. @@ -1081,15 +1089,23 @@ def _consume_etrecord(self) -> None: bundled_input_index of the EventBlock. """ - if self._etrecord is None: - return + if method_name is None and module_name is None: + method_name = FORWARD + edge_dialect_graph_key = EDGE_DIALECT_GRAPH_KEY + elif method_name is None or module_name is None: + raise ValueError( + "Either both method_name and module_name should be provided or neither should be provided" + ) + else: + method_name = method_name + edge_dialect_graph_key = f"{module_name}/{method_name}" # (1) Debug Handle Symbolification for event_block in self.event_blocks: event_block._gen_resolve_debug_handles( - self._etrecord._debug_handle_map[FORWARD], + self._etrecord._debug_handle_map[method_name], ( - self._etrecord._delegate_map[FORWARD] + self._etrecord._delegate_map[method_name] if self._etrecord._delegate_map is not None else None ), @@ -1099,9 +1115,10 @@ def _consume_etrecord(self) -> None: self.op_graph_dict = gen_graphs_from_etrecord( etrecord=self._etrecord, enable_module_hierarchy=self._enable_module_hierarchy, + edge_dialect_graph_key=edge_dialect_graph_key, ) debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping( - self.op_graph_dict[EDGE_DIALECT_GRAPH_KEY], + self.op_graph_dict[edge_dialect_graph_key], ) for event_block in self.event_blocks: for event in event_block.events: @@ -1116,7 +1133,7 @@ def _consume_etrecord(self) -> None: for event_block in self.event_blocks: index = event_block.bundled_input_index if index is not None: - event_block.reference_output = self._reference_outputs[FORWARD][ + event_block.reference_output = self._reference_outputs[method_name][ index ] diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 3faf62e008a..2931417ab2e 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -236,7 +236,9 @@ def is_debug_output(value: Value) -> bool: def gen_graphs_from_etrecord( - etrecord: ETRecord, enable_module_hierarchy: bool = False + etrecord: ETRecord, + enable_module_hierarchy: bool = False, + edge_dialect_graph_key: str = EDGE_DIALECT_GRAPH_KEY, ) -> Mapping[str, OperatorGraph]: op_graph_map = {} if etrecord.graph_map is not None: @@ -248,7 +250,7 @@ def gen_graphs_from_etrecord( for name, exported_program in etrecord.graph_map.items() } if etrecord.edge_dialect_program is not None: - op_graph_map[EDGE_DIALECT_GRAPH_KEY] = FXOperatorGraph.gen_operator_graph( + op_graph_map[edge_dialect_graph_key] = FXOperatorGraph.gen_operator_graph( etrecord.edge_dialect_program.graph_module, enable_module_hierarchy=enable_module_hierarchy, ) diff --git a/devtools/inspector/inspector_cli.py b/devtools/inspector/inspector_cli.py index 00e74cc25f8..17bbf74acfb 100644 --- a/devtools/inspector/inspector_cli.py +++ b/devtools/inspector/inspector_cli.py @@ -48,6 +48,18 @@ def main() -> None: required=False, help="Provide an optional tsv file path.", ) + parser.add_argument( + "--method_name", + required=False, + default=None, + help="Method Name to inspect (used with multi-module exports)", + ) + parser.add_argument( + "--module_name", + required=False, + default=None, + help="Module Name to inspect (used with multi-module exports)", + ) parser.add_argument("--compare_results", action="store_true") args = parser.parse_args() @@ -58,6 +70,8 @@ def main() -> None: debug_buffer_path=args.debug_buffer_path, source_time_scale=TimeScale(args.source_time_scale), target_time_scale=TimeScale(args.target_time_scale), + module_name=args.module_name, + method_name=args.method_name, ) inspector.print_data_tabular() if args.tsv_path: