diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index c5e4bbc9a06..17a7451aadf 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -62,6 +62,7 @@ map_runtime_aot_intermediate_outputs, merge_runtime_overlapping_debug_handles, ProgramOutput, + propagate_back_debug_handle, RESERVED_FRAMEWORK_EVENT_NAMES, TimeScale, verify_debug_data_equivalence, @@ -1166,7 +1167,18 @@ def _get_aot_intermediate_outputs_and_op_names( """ if self._etrecord._representative_inputs is None: return {}, {} - export_program = self._etrecord.edge_dialect_program + + export_program = None + + # Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is the greatest ancestor of the edge_dialect_program + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + ): + export_program = self._etrecord.exported_program + else: + export_program = self._etrecord.edge_dialect_program graph_module = export_program.module() aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index c36311afeab..37dc7921923 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -25,7 +25,6 @@ from executorch.devtools import generate_etrecord, parse_etrecord from executorch.devtools.debug_format.et_schema import OperatorNode from executorch.devtools.etdump.schema_flatcc import ProfileEvent -from executorch.devtools.etrecord._etrecord import ETRecord from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord from executorch.devtools.inspector import ( @@ -480,7 +479,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self): events=events, ) - def test_etrecord_populates_correct_aot_intermediate_outputs(self): + def test_etrecord_populates_correct_edge_dialect_aot_intermediate_outputs(self): with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: etrecord_path = tmp_file.name mod = model_registry["ConvLinearModel"]() @@ -513,15 +512,11 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self): etdump_path=ETDUMP_PATH, etrecord=etrecord_path, ) - etrecord = ETRecord( - edge_dialect_program=inspector_instance._etrecord.edge_dialect_program, - graph_map=inspector_instance._etrecord.graph_map, - _debug_handle_map=inspector_instance._etrecord._debug_handle_map, - _delegate_map=inspector_instance._etrecord._delegate_map, - _reference_outputs=inspector_instance._etrecord._reference_outputs, - _representative_inputs=aten_model.example_inputs[0], + + inspector_instance._etrecord._representative_inputs = ( + aten_model.example_inputs[0] ) - inspector_instance._etrecord = etrecord + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( inspector_instance._get_aot_intermediate_outputs_and_op_names() ) @@ -534,7 +529,61 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self): self.assertTrue( check_if_debug_handle_to_op_names_match( - "ConvLinearModel", aot_debug_handle_to_op_names + aot_debug_handle_to_op_names, + mod.get_edge_dialect_expected_debug_handle_to_op_names(), + ) + ) + + def test_etrecord_populates_correct_export_program_aot_intermediate_outputs(self): + with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: + etrecord_path = tmp_file.name + mod = model_registry["ConvLinearModel"]() + input_tensor = mod.get_input() + aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True) + edge_program_manager: EdgeProgramManager = to_edge(aten_model) + edge_program_manager_copy = copy.deepcopy(edge_program_manager) + et_program_manager: ExecutorchProgramManager = ( + edge_program_manager.to_executorch() + ) + # Generate ETRecord with the exported program + generate_etrecord( + etrecord_path, + edge_program_manager_copy, + et_program_manager, + exported_program=aten_model, + ) + with patch.object( + Inspector, "_consume_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=etrecord_path, + ) + + inspector_instance._etrecord._representative_inputs = ( + aten_model.example_inputs[0] + ) + + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( + inspector_instance._get_aot_intermediate_outputs_and_op_names() + ) + self.assertTrue( + check_if_intermediate_outputs_match( + aot_intermediate_outputs, + mod.get_exported_program_expected_intermediate_outputs(), + ) + ) + self.assertTrue( + check_if_debug_handle_to_op_names_match( + aot_debug_handle_to_op_names, + mod.get_exported_program_expected_debug_handle_to_op_names(), ) ) diff --git a/devtools/inspector/tests/inspector_test_utils.py b/devtools/inspector/tests/inspector_test_utils.py index da426377564..69c787608b1 100644 --- a/devtools/inspector/tests/inspector_test_utils.py +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -79,7 +79,7 @@ def get_edge_dialect_expected_intermediate_outputs(): } @staticmethod - def get_expected_debug_handle_to_op_names(): + def get_edge_dialect_expected_debug_handle_to_op_names(): """ Returns the expected debug handle and op names mapping for this model for the given input. """ @@ -100,7 +100,7 @@ def get_expected_debug_handle_to_op_names(): @staticmethod def get_exported_program_expected_intermediate_outputs(): """ - Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input. + Returns the expected outputs of the debug handles and intermediate output mapping for export graph of this model for the given input. """ return { (UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]), @@ -117,6 +117,26 @@ def get_exported_program_expected_intermediate_outputs(): (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], } + @staticmethod + def get_exported_program_expected_debug_handle_to_op_names(): + """ + Returns the expected debug handle and op name mapping for this model for the given input. + """ + return { + (UNSET_DEBUG_HANDLE,): ["_assert_tensor_metadata_default", "to"], + (1,): ["conv2d"], + (2,): ["view"], + (3,): ["linear"], + (4,): ["add"], + (5,): ["sub"], + (6,): ["mul"], + (7,): ["add_1"], + (8,): ["div"], + (9,): ["relu"], + (10,): ["sigmoid"], + (11,): ["split"], + } + # Global model registry model_registry = { @@ -153,15 +173,13 @@ def check_if_intermediate_outputs_match( return True -def check_if_debug_handle_to_op_names_match(model_name, actual_debug_handle_to_op_name): +def check_if_debug_handle_to_op_names_match( + actual_debug_handle_to_op_name, expected_debug_handle_to_op_name +): """ Checks if the actual op names match the expected op names for the specified model. Returns True if all match, otherwise returns False. """ - model_instance = model_registry[model_name] - expected_debug_handle_to_op_name = ( - model_instance.get_expected_debug_handle_to_op_names() - ) if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name): return False for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items():