Skip to content

Commit fe31c6c

Browse files
Support multi-module exports in Inspector and ETRecord
- Added optional module_name and method_name parameters to Inspector initialization - Updated _consume_etrecord method to handle multi-module export scenarios - Modified gen_graphs_from_etrecord to support custom graph keys (for each module/method) - Updated inspector_cli to accept module and method name arguments - Improved error handling for multi-module export use cases
1 parent d99970b commit fe31c6c

File tree

4 files changed

+39
-13
lines changed

4 files changed

+39
-13
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,10 @@ def generate_etrecord(
232232
edge_dialect_program.exported_program,
233233
)
234234
else:
235-
raise RuntimeError(
236-
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
237-
)
235+
if export_modules is None:
236+
raise RuntimeError(
237+
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
238+
)
238239

239240
# When a BundledProgram is passed in, extract the reference outputs and save in a file
240241
if isinstance(executorch_program, BundledProgram):

devtools/inspector/_inspector.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,8 @@ def __init__(
978978
Callable[[Union[int, str], Union[int, float]], Union[int, float]]
979979
] = None,
980980
enable_module_hierarchy: bool = False,
981+
module_name: Optional[str] = None,
982+
method_name: Optional[str] = None,
981983
) -> None:
982984
r"""
983985
Initialize an `Inspector` instance with the underlying `EventBlock`\ s populated with data from the provided ETDump path or binary,
@@ -995,6 +997,8 @@ def __init__(
995997
delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of
996998
target_time_scale/source_time_scale.
997999
enable_module_hierarchy: Enable submodules in the operator graph. Defaults to False.
1000+
module_name: Optional module name to inspect (used with multi-module exports).
1001+
method_name: Optional method name to inspect (used with multi-module exports).
9981002
9991003
Returns:
10001004
None
@@ -1059,9 +1063,9 @@ def __init__(
10591063
# Key str is method name; value is list of ProgramOutputs because of list of test cases
10601064
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
10611065
self._enable_module_hierarchy = enable_module_hierarchy
1062-
self._consume_etrecord()
1066+
self._consume_etrecord(module_name, method_name)
10631067

1064-
def _consume_etrecord(self) -> None:
1068+
def _consume_etrecord(self, module_name: Optional[str] = None, method_name: Optional[str] = None,) -> None:
10651069
"""
10661070
If an ETRecord is provided, connect it to the EventBlocks and populate the Event metadata.
10671071
@@ -1081,15 +1085,21 @@ def _consume_etrecord(self) -> None:
10811085
bundled_input_index of the EventBlock.
10821086
"""
10831087

1084-
if self._etrecord is None:
1085-
return
1088+
if method_name is None and module_name is None:
1089+
method_name = FORWARD
1090+
edge_dialect_graph_key = EDGE_DIALECT_GRAPH_KEY
1091+
elif method_name is None or module_name is None:
1092+
raise ValueError("Either both method_name and module_name should be provided or neither should be provided")
1093+
else:
1094+
method_name = method_name
1095+
edge_dialect_graph_key = f"{module_name}/{method_name}"
10861096

10871097
# (1) Debug Handle Symbolification
10881098
for event_block in self.event_blocks:
10891099
event_block._gen_resolve_debug_handles(
1090-
self._etrecord._debug_handle_map[FORWARD],
1100+
self._etrecord._debug_handle_map[method_name],
10911101
(
1092-
self._etrecord._delegate_map[FORWARD]
1102+
self._etrecord._delegate_map[method_name]
10931103
if self._etrecord._delegate_map is not None
10941104
else None
10951105
),
@@ -1099,9 +1109,10 @@ def _consume_etrecord(self) -> None:
10991109
self.op_graph_dict = gen_graphs_from_etrecord(
11001110
etrecord=self._etrecord,
11011111
enable_module_hierarchy=self._enable_module_hierarchy,
1112+
edge_dialect_graph_key=edge_dialect_graph_key,
11021113
)
11031114
debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping(
1104-
self.op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
1115+
self.op_graph_dict[edge_dialect_graph_key],
11051116
)
11061117
for event_block in self.event_blocks:
11071118
for event in event_block.events:
@@ -1116,7 +1127,7 @@ def _consume_etrecord(self) -> None:
11161127
for event_block in self.event_blocks:
11171128
index = event_block.bundled_input_index
11181129
if index is not None:
1119-
event_block.reference_output = self._reference_outputs[FORWARD][
1130+
event_block.reference_output = self._reference_outputs[method_name][
11201131
index
11211132
]
11221133

devtools/inspector/_inspector_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def is_debug_output(value: Value) -> bool:
236236

237237

238238
def gen_graphs_from_etrecord(
239-
etrecord: ETRecord, enable_module_hierarchy: bool = False
239+
etrecord: ETRecord, enable_module_hierarchy: bool = False, edge_dialect_graph_key: str = EDGE_DIALECT_GRAPH_KEY
240240
) -> Mapping[str, OperatorGraph]:
241241
op_graph_map = {}
242242
if etrecord.graph_map is not None:
@@ -248,7 +248,7 @@ def gen_graphs_from_etrecord(
248248
for name, exported_program in etrecord.graph_map.items()
249249
}
250250
if etrecord.edge_dialect_program is not None:
251-
op_graph_map[EDGE_DIALECT_GRAPH_KEY] = FXOperatorGraph.gen_operator_graph(
251+
op_graph_map[edge_dialect_graph_key] = FXOperatorGraph.gen_operator_graph(
252252
etrecord.edge_dialect_program.graph_module,
253253
enable_module_hierarchy=enable_module_hierarchy,
254254
)

devtools/inspector/inspector_cli.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ def main() -> None:
4848
required=False,
4949
help="Provide an optional tsv file path.",
5050
)
51+
parser.add_argument(
52+
"--method_name",
53+
required=False,
54+
default=None,
55+
help="Method Name to inspect (used with multi-module exports)",
56+
)
57+
parser.add_argument(
58+
"--module_name",
59+
required=False,
60+
default=None,
61+
help="Module Name to inspect (used with multi-module exports)",
62+
)
5163
parser.add_argument("--compare_results", action="store_true")
5264

5365
args = parser.parse_args()
@@ -58,6 +70,8 @@ def main() -> None:
5870
debug_buffer_path=args.debug_buffer_path,
5971
source_time_scale=TimeScale(args.source_time_scale),
6072
target_time_scale=TimeScale(args.target_time_scale),
73+
module_name=args.module_name,
74+
method_name=args.method_name,
6175
)
6276
inspector.print_data_tabular()
6377
if args.tsv_path:

0 commit comments

Comments
 (0)