diff --git a/exir/backend/test/test_delegate_map_builder.py b/exir/backend/test/test_delegate_map_builder.py index b93262d2dc6..827cb8cdebc 100644 --- a/exir/backend/test/test_delegate_map_builder.py +++ b/exir/backend/test/test_delegate_map_builder.py @@ -45,19 +45,19 @@ def forward(self, x): def test_basic_generated_identifier(self): delegate_builder = DelegateMappingBuilder(generated_identifiers=True) - expected_mapping = {0: (0, 1, 2, 3)} + expected_mapping = {0: (1, 2, 3, 4)} self.assertEqual( delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0 ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) - expected_mapping = {0: (0, 1, 2, 3), 1: (0,)} + expected_mapping = {0: (1, 2, 3, 4), 1: (1,)} self.assertEqual( delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[0]), 1 ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) - expected_mapping = {0: (0, 1, 2, 3), 1: (0,), 2: (1,)} + expected_mapping = {0: (1, 2, 3, 4), 1: (1,), 2: (2,)} self.assertEqual( delegate_builder.insert_delegate_mapping_entry(handles=self.handles[2]), 2, @@ -65,10 +65,10 @@ def test_basic_generated_identifier(self): self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) expected_mapping = { - 0: (0, 1, 2, 3), - 1: (0,), - 2: (1,), - 3: (0, 1, 2, 3), + 0: (1, 2, 3, 4), + 1: (1,), + 2: (2,), + 3: (1, 2, 3, 4), } self.assertEqual( delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3 @@ -144,7 +144,7 @@ def test_backend_with_delegate_mapping(self) -> None: self.assertEqual(len(debug_handle_map), 5) # Check to see that all the delegate debug indexes in the range [0,2] are present. self.assertTrue( - all(element in debug_handle_map.keys() for element in [0, 1, 2, 3]) + all(element in debug_handle_map.keys() for element in [1, 2, 3, 4]) ) class CompositeModule(torch.nn.Module): @@ -200,7 +200,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]): # Entry with a list of nodes iden_1 = next(identifiers) - expected_mapping = {iden_1: (0, 1, 2, 3)} + expected_mapping = {iden_1: (1, 2, 3, 4)} self.assertEqual( delegate_builder_nodes.insert_delegate_mapping_entry( nodes=self.nodes, identifier=iden_1 @@ -222,7 +222,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]): # Entry with a single node iden_2 = next(identifiers) - expected_mapping = {iden_1: (0, 1, 2, 3), iden_2: (0,)} + expected_mapping = {iden_1: (1, 2, 3, 4), iden_2: (1,)} self.assertEqual( delegate_builder_nodes.insert_delegate_mapping_entry( nodes=self.nodes[0], identifier=iden_2 diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index 44c86550559..bf2b99da9da 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -6,6 +6,7 @@ from executorch.exir.graph_module import get_control_flow_submodules from executorch.exir.pass_base import ExportPass +from torch.export import ExportedProgram from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult @@ -17,7 +18,7 @@ def call(self, graph_module: GraphModule) -> PassResult: """ queue = [graph_module] - index = 0 + index = 1 # bfs to traverse all modules including control flow submodules to attached debug handle id while queue: current_graph_module = queue.pop(0) @@ -30,3 +31,35 @@ def call(self, graph_module: GraphModule) -> PassResult: ] queue.extend(control_flow_submodules) return PassResult(graph_module, True) + + +def generate_missing_debug_handles(ep: ExportedProgram): + """ + This pass is used to generate missing debug handles for the graph module and its submodules. + """ + + def get_control_flow_submodules_list(graph_module): + return [ + submodule for _, submodule, _ in get_control_flow_submodules(graph_module) + ] + + max_handle = 0 + queue = [ep.graph_module] + + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + if "debug_handle" in node.meta: + max_handle = max(max_handle, node.meta["debug_handle"]) + control_flow_submodules = get_control_flow_submodules_list(ep.graph_module) + queue.extend(control_flow_submodules) + + queue = [ep.graph_module] + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + if node.meta.get("debug_handle", 0) in (0, None): + node.meta["debug_handle"] = max_handle + 1 + max_handle += 1 + control_flow_submodules = get_control_flow_submodules_list(ep.graph_module) + queue.extend(control_flow_submodules) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 79578763475..ae5a1adaeb2 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -33,7 +33,10 @@ ToOutVarPass, ) from executorch.exir.passes.constant_prop_pass import constant_prop_pass -from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass +from executorch.exir.passes.debug_handle_generator_pass import ( + DebugHandleGeneratorPass, + generate_missing_debug_handles, +) from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) @@ -949,13 +952,28 @@ def test_debug_handle_generator_pass(self) -> None: .exported_program() .graph_module ) - DebugHandleGeneratorPass()(graph_module) for node in graph_module.graph.nodes: self.assertIn("debug_handle", node.meta) ScalarToTensorPass()(graph_module) for node in graph_module.graph.nodes: self.assertIn("debug_handle", node.meta) + def test_generate_missing_debug_handles(self) -> None: + eager_model = MLP(2, output_size=4) + inputs = eager_model.get_random_inputs() + + ep = to_edge( + export( + eager_model, + inputs, + ) + ).exported_program() + + list(ep.graph.nodes)[0].meta.pop("debug_handle") + self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None) + generate_missing_debug_handles(ep) + self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None) + def test_debug_handle_generator_pass_with_control_flow(self) -> None: def true_nested(y: torch.Tensor) -> torch.Tensor: y = y + y