diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 838156498c4..d8210e7433a 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -238,7 +238,6 @@ def generate_debug_handle(ep: ExportedProgram) -> int: call_delegate_node.meta["val"] = submodule_output_node.meta["val"] call_submodule_node.replace_all_uses_with(call_delegate_node) owning_graph_module.graph.erase_node(call_submodule_node) - if is_submodule: assert len(toplevel_input_specs_to_delete) == 0 assert len(toplevel_output_specs_to_delete) == 0 @@ -574,26 +573,29 @@ def lower_all_submodules_to_backend( # The created exported program for the submodules are in the call_module node's meta data # We just map the method_to_submodule_nodes directly to the method_to_partitioned_exported_programs method_to_partitioned_program = { - method_name: [node.meta["submodule_program"] for node in call_submodule_nodes] + method_name: [ + copy.deepcopy(node.meta["submodule_program"]) + for node in call_submodule_nodes + ] for method_name, call_submodule_nodes in method_to_submodules_nodes.items() } method_to_compile_specs = { method_name: [node.meta["compile_spec"] for node in call_submodule_nodes] for method_name, call_submodule_nodes in method_to_submodules_nodes.items() } - backend_found = False - for cls in BackendDetails.__subclasses__(): - if backend_id == cls.__name__: - method_to_preprocess_result: dict[str, List[PreprocessResult]] = ( - cls.preprocess_multimethod( - method_to_partitioned_program, method_to_compile_specs - ) - ) - backend_found = True - if not backend_found: + backend_name_to_subclass = { + subclass.__name__: subclass for subclass in BackendDetails.__subclasses__() + } + if backend_id not in backend_name_to_subclass: raise NotImplementedError(f"Backend {backend_id} was not found.") + method_to_preprocess_result: dict[str, List[PreprocessResult]] = ( + backend_name_to_subclass[backend_id].preprocess_multimethod( + method_to_partitioned_program, method_to_compile_specs + ) + ) + for method_name in method_to_preprocess_result.keys(): owning_program = method_to_tagged_edge_program[method_name] list_of_preprocess_results = method_to_preprocess_result[method_name] @@ -612,6 +614,9 @@ def lower_all_submodules_to_backend( compile_specs=compile_spec, named_data_store_output=preprocess_result.data_store_output, ) + lowered_module.meta = { + "debug_handle_map": preprocess_result.debug_handle_map, + } is_submodule = call_submodule_node.meta["is_submodule"] toplevel_input_specs_to_delete = call_submodule_node.meta[ "toplevel_input_specs_to_delete" @@ -633,6 +638,20 @@ def lower_all_submodules_to_backend( ) +def remove_used_metadata(graph: torch.fx.Graph) -> None: + """ + Remove the used metadata from the graph. + """ + for node in graph.nodes: + node.meta.pop("delegation_tag", None) + node.meta.pop("backend_id", None) + node.meta.pop("submodule_program", None) + node.meta.pop("toplevel_input_specs_to_delete", None) + node.meta.pop("toplevel_output_specs_to_delete", None) + node.meta.pop("is_submodule", None) + node.meta.pop("submodule_output_node", None) + + @dataclass class MethodProgramsPartitionerSpec: """ @@ -748,6 +767,7 @@ def to_backend( if method_name in method_to_tagged_exported_program: tagged_exported_program = method_to_tagged_exported_program[method_name] tagged_exported_program._validate() + remove_used_metadata(tagged_exported_program.graph_module.graph) partitioned_and_lowered_exported_programs[method_name] = ExportedProgram( root=tagged_exported_program.graph_module, graph=tagged_exported_program.graph_module.graph, diff --git a/exir/program/_program.py b/exir/program/_program.py index f10433c42ae..f24807e253d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -23,7 +23,10 @@ from executorch.exir._serialize._serialize import serialize_for_executorch from executorch.exir._serialize.data_serializer import DataSerializer from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.backend_api import ( + MethodProgramsPartitionerSpec, + to_backend, +) from executorch.exir.backend.partitioner import Partitioner from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig from executorch.exir.delegate import executorch_call_delegate, is_lowered_module @@ -1239,10 +1242,16 @@ def to_edge_transform_and_lower( if transform_passes is not None: edge_manager = edge_manager.transform(transform_passes) - if partitioner is not None: + max_num_partitioners = 0 + for partitioner_list in partitioner.values(): + max_num_partitioners = max(max_num_partitioners, len(partitioner_list)) + + for i in range(max_num_partitioners): + method_to_partitioner = {} for name, partitioner_list in partitioner.items(): - for curr_partitioner in partitioner_list: - edge_manager = edge_manager.to_backend({name: curr_partitioner}) + if i < len(partitioner_list): + method_to_partitioner[name] = partitioner_list[i] + edge_manager = edge_manager.to_backend(method_to_partitioner) for name, program in edge_manager._edge_programs.items(): ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set() @@ -1475,7 +1484,8 @@ def transform( @et_logger("to_backend") def to_backend( - self, partitioner: Union[Partitioner, Dict[str, Partitioner]] + self, + partitioner: Union[Partitioner, Dict[str, Partitioner]], ) -> "EdgeProgramManager": """ Returns a semantically-equivalent program to the one given as input, @@ -1501,17 +1511,18 @@ def to_backend( specified subgraphs lowered. """ new_edge_programs: Dict[str, ExportedProgram] = {} - if isinstance(partitioner, dict): - for name, program in self._edge_programs.items(): - if name in partitioner.keys(): - new_edge_programs[name] = to_backend(program, partitioner[name]) - else: - new_edge_programs[name] = program + method_to_partitioner: Dict[str, Partitioner] = {} + if not isinstance(partitioner, dict): + method_to_partitioner = {name: partitioner for name in self._edge_programs} + else: + method_to_partitioner = partitioner - else: # apply partitioner to every method - for name, program in self._edge_programs.items(): - new_edge_programs[name] = to_backend(program, partitioner) + method_to_programs_and_partitioners = MethodProgramsPartitionerSpec( + self._edge_programs, + method_to_partitioner, + ) + new_edge_programs = to_backend(method_to_programs_and_partitioners) config = EdgeCompileConfig(_check_ir_validity=False) return EdgeProgramManager( new_edge_programs,