diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 6e4d1c7502f..7209b0a2678 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -350,7 +350,6 @@ def lower_module_and_test_output(
                 # Therefore, won't want to pre-allocate
                 # by memory manager in runtime.
                 memory_planning_pass=MemoryPlanningPass(
-                    memory_planning_algo="greedy",
                     alloc_graph_input=not self.shared_buffer,
                     alloc_graph_output=not self.shared_buffer,
                 ),
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 1865c32acd7..7e85c25faee 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -57,7 +57,7 @@ def preprocess(  # noqa: C901
             MeanToSumDiv(),
             SpecPropPass(),
             ConstraintBasedSymShapeEvalPass(),
-            MemoryPlanningPass("greedy"),
+            MemoryPlanningPass(),
         ]
 
         new_gm = program.graph_module
diff --git a/docs/source/compiler-memory-planning.md b/docs/source/compiler-memory-planning.md
index f7bda678e4c..0f4489654b4 100644
--- a/docs/source/compiler-memory-planning.md
+++ b/docs/source/compiler-memory-planning.md
@@ -32,7 +32,6 @@ The `MemoryPlanningPass` exposes the option to not memory plan program inputs an
 program = edge_program.to_executorch(
     exir.ExecutorchBackendConfig(
         memory_planning_pass=MemoryPlanningPass(
-            memory_planning_algo="greedy",
             alloc_graph_input=False, # Inputs will not be memory planned, the data_ptr for input tensors after model load will be nullptr
             alloc_graph_output=True, # Outputs will be memory planned, the data_ptr for output tensors after model load will be in the `planned_memory`.
         )
@@ -77,7 +76,7 @@ Then later when lowering to ExecuTorch you can use your custom plan in the follo
 program = edge_program.to_executorch(
     exir.ExecutorchBackendConfig(
         memory_planning_pass=CustomPoolMemoryPlanningPass(
-            memory_planning_algo="greedy",
+            memory_planning_algo=greedy,
         )
     )
 )
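Note on the docs change above: `memory_planning_algo` now takes the algorithm function itself rather than its name, so the snippet only works with `greedy` in scope. A minimal sketch of the new call shape, assuming the `executorch.exir.memory_planning` import path used by the tests later in this diff (`CustomPoolMemoryPlanningPass` is the docs' own example subclass, not a library class):

```python
from executorch.exir.memory_planning import greedy
from executorch.exir.passes import MemoryPlanningPass

# The algorithm argument is now a callable, not a string.
default_pass = MemoryPlanningPass()                              # defaults to greedy
explicit_pass = MemoryPlanningPass(memory_planning_algo=greedy)  # same algorithm, spelled out
```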
diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py
index 57650abda41..fac3eab08e5 100644
--- a/docs/source/tutorials_source/export-to-executorch-tutorial.py
+++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py
@@ -523,9 +523,7 @@ def forward(self, a, x, b):
 executorch_program: ExecutorchProgramManager = edge_program.to_executorch(
     ExecutorchBackendConfig(
         passes=[],  # User-defined passes
-        memory_planning_pass=MemoryPlanningPass(
-            "greedy"
-        ),  # Default memory planning pass
+        memory_planning_pass=MemoryPlanningPass(),  # Default memory planning pass
     )
 )
diff --git a/examples/mediatek/model_export_scripts/llama.py b/examples/mediatek/model_export_scripts/llama.py
index 980a502c5ae..b2fef26a4cf 100644
--- a/examples/mediatek/model_export_scripts/llama.py
+++ b/examples/mediatek/model_export_scripts/llama.py
@@ -365,7 +365,6 @@ def export_to_et_ir(
     executorch_program = delegated_program.to_executorch(
         config=exir.ExecutorchBackendConfig(
             memory_planning_pass=exir.passes.MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=False,
                 alloc_graph_output=False,
             ),
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index b1e9d64ee9f..47a5407cf18 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -233,7 +233,7 @@ def export_all(llava_model: LlavaModel):
             passes=[
                 QuantFusionPass(),
             ],
-            memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False),
+            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             sym_shape_eval_pass={
                 "image_encoder": ConstraintBasedSymShapeEvalPass(),
                 "text_model": ConstraintBasedSymShapeEvalPass(),
diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py
index 9712b1c08fd..d74cfa0ef07 100644
--- a/examples/qualcomm/oss_scripts/llama2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama2/llama.py
@@ -311,7 +311,6 @@ def lowering_modules(
             # Therefore, won't want to pre-allocate
             # by memory manager in runtime.
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=False,
                 alloc_graph_output=False,
             ),
diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py
index 9dfe6796491..b742f59f1d4 100644
--- a/examples/qualcomm/qaihub_scripts/utils/export.py
+++ b/examples/qualcomm/qaihub_scripts/utils/export.py
@@ -220,7 +220,6 @@ def compile(args):
     )
     # setup memory planning
     memory_planning_pass = MemoryPlanningPass(
-        memory_planning_algo="greedy",
         alloc_graph_input=args.allocate_graph_io,
         alloc_graph_output=args.allocate_graph_io,
     )
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 088ffe7ea14..9c4cd4453f0 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -285,7 +285,6 @@ def build_executorch_binary(
             # Therefore, won't want to pre-allocate
             # by memory manager in runtime.
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=not shared_buffer,
                 alloc_graph_output=not shared_buffer,
             ),
diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index 11a0d6d069d..24865e7a841 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -56,9 +56,7 @@ class ExecutorchBackendConfig:
 
     # A single memory planning pass can be defined for all the programs in the
     # EdgeProgramManager or can be defined per program.
-    memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass(
-        "greedy"
-    )
+    memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass()
     to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False)
     dynamic_memory_planning_mode: DynamicMemoryPlanningMode = (
         DynamicMemoryPlanningMode.UPPER_BOUND
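As the `Dict[str, PassType]` annotation above indicates, `memory_planning_pass` accepts either a single pass for every entry point or a per-method mapping. A minimal sketch of the per-method form under the new API (method names are illustrative; the `test_program.py` change later in this diff exercises exactly this shape):

```python
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass

# One MemoryPlanningPass per exported method; "forward" and "foo" are
# placeholder method names.
config = ExecutorchBackendConfig(
    memory_planning_pass={
        "forward": MemoryPlanningPass(alloc_graph_input=True, alloc_graph_output=False),
        "foo": MemoryPlanningPass(alloc_graph_input=False, alloc_graph_output=True),
    }
)
```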
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 123896ecdba..4d362d1b516 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1145,7 +1145,6 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
         config = exir.ExecutorchBackendConfig(
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 # allow_lifetime_and_storage_overlap: bool = False,
                 alloc_graph_input=True,
                 alloc_graph_output=False,
@@ -1606,9 +1605,7 @@ def forward(self, x):
         )
         model = model.to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py
index e50d3038dac..bc42bba9a26 100644
--- a/exir/lowered_backend_module.py
+++ b/exir/lowered_backend_module.py
@@ -326,7 +326,7 @@ def program(
         verifiers=[lowered_exported_program.verifier],
     )
     if memory_planning is None:
-        memory_planning = MemoryPlanningPass("greedy")
+        memory_planning = MemoryPlanningPass()
     exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
     emitted_program = emit_program(
         exported_program, emit_stacktrace=emit_stacktrace
diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 859bd069013..3c28639ba13 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -18,12 +18,7 @@
 from executorch.exir import memory
 from executorch.exir.control_flow import while_loop as exir_while
 from executorch.exir.delegate import executorch_call_delegate
-from executorch.exir.error import (
-    ExportError,
-    ExportErrorType,
-    internal_assert,
-    InternalError,
-)
+from executorch.exir.error import internal_assert, InternalError
 from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
 from executorch.exir.schema import TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
@@ -255,17 +250,6 @@ def verify_graph_input_output(self) -> None:
             ), f"Misallocate graph output {graph_output_allocated} v.s. {self.alloc_graph_output}"
 
 
-def register_algo(fn: Callable[..., List[int]]) -> Callable[..., List[int]]:
-    algo_name = fn.__name__
-    if algo_name in REGISTERED_ALGOS:
-        raise ExportError(
-            ExportErrorType.VIOLATION_OF_SPEC,
-            f"Re-registering memory planning algorithm {algo_name}",
-        )
-    REGISTERED_ALGOS[algo_name] = fn
-    return fn
-
-
 def _is_out_var_node(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
@@ -561,7 +545,6 @@ def get_node_tensor_specs(
 ]
 
 
-@register_algo
 def greedy(
     graph_module: torch.fx.GraphModule,
     alignment: int,
@@ -615,7 +598,6 @@ def greedy(
     return total_sizes
 
 
-@register_algo
 def naive(
     graph_module: torch.fx.GraphModule,
     alignment: int,
@@ -656,15 +638,6 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     return bufsizes
 
 
-def get_algo(algo_name: str) -> Callable[..., List[int]]:
-    if algo_name not in REGISTERED_ALGOS:
-        raise ExportError(
-            ExportErrorType.NOT_SUPPORTED,
-            f"Memory planning algorithm '{algo_name}' not found",
-        )
-    return REGISTERED_ALGOS[algo_name]
-
-
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
     for nd in graph_module.graph.nodes:
         if nd.target is torch.ops.higher_order.cond:
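With `register_algo` and `get_algo` removed above, there is no registry to extend: a custom planner is simply any callable with the same shape as `greedy`/`naive`, returning the planned buffer size per memory id. A sketch of a wrapper planner under that assumption (`logging_planner` is hypothetical, and the parts of the `greedy` signature not shown in this diff, such as the `graph_signature` parameter, are assumed from context):

```python
from typing import List, Optional

import torch
from torch.export import ExportGraphSignature

from executorch.exir.memory_planning import greedy
from executorch.exir.passes import MemoryPlanningPass


def logging_planner(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    # Delegate the actual planning to greedy, then report what it produced.
    bufsizes = greedy(
        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
    )
    print(f"planned buffer sizes per mem_id: {bufsizes}")
    return bufsizes


# Passed directly where a registered algorithm name used to go.
mem_pass = MemoryPlanningPass(memory_planning_algo=logging_planner)
```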
diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py
index 9295cabcab6..112b8f5fc52 100644
--- a/exir/passes/memory_planning_pass.py
+++ b/exir/passes/memory_planning_pass.py
@@ -6,7 +6,7 @@
 
 import logging
 import warnings
-from typing import Optional
+from typing import Callable, List, Optional
 
 import torch
 from executorch.exir.error import internal_assert
@@ -14,8 +14,8 @@
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
-    get_algo,
     get_node_tensor_specs,
+    greedy,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -27,7 +27,7 @@ class MemoryPlanningPass(PassBase):
 
     def __init__(
         self,
-        memory_planning_algo: str = "greedy",
+        memory_planning_algo: Callable[..., List[int]] = greedy,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
@@ -96,14 +96,13 @@ def run(
             memory_planning_algo
         """
         self._set_alloc_node_spec(graph_module)
-        algo = get_algo(self.memory_planning_algo)
         # TODO(shunting) if people have concern of adding a field to GraphModule
         # directly, we should define a GraphModule subclass that we can add our
         # customized fields. Using the graph_module object to convey information across
         # passes/stages is quite natural and avoid yet another 'context' data structure
         # to do the job.
         _ = apply_algo(
-            algo,
+            self.memory_planning_algo,
             graph_module,
             self.alignment,
             graph_signature,
@@ -125,7 +124,7 @@ def run(
                 self.allow_lifetime_and_storage_overlap
             )
             logging.debug(
-                f"The {self.memory_planning_algo} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
+                f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
             )
             verifier.verify_graph_input_output()
         return PassResult(graph_module, True)
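The `getattr(..., '__name__', repr(...))` in the new debug line exists because the algorithm is no longer a printable string: plain functions carry `__name__`, but other callables one might plausibly pass (e.g. a `functools.partial`) do not. A standalone illustration of the fallback (the stub `my_algo` is hypothetical):

```python
import functools


def my_algo(*args):
    # Stand-in for a planning callable; the body is irrelevant here.
    ...


for algo in (my_algo, functools.partial(my_algo)):
    # Functions have __name__; partial objects fall back to repr().
    print(getattr(algo, "__name__", repr(algo)))
# my_algo
# functools.partial(<function my_algo at 0x...>)
```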
diff --git a/exir/program/TARGETS b/exir/program/TARGETS
index 730c9e93aed..fc73abf1ff7 100644
--- a/exir/program/TARGETS
+++ b/exir/program/TARGETS
@@ -22,6 +22,7 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir:error",
         "//executorch/exir:graph_module",
+        "//executorch/exir:pass_base",
         "//executorch/exir:pass_manager",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py
index edad4b24f1c..73eea7b93ef 100644
--- a/exir/program/test/test_program.py
+++ b/exir/program/test/test_program.py
@@ -250,12 +250,10 @@ def test_executorch_manager_multi_config(self):
         def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
             return {
                 "forward": MemoryPlanningPass(
-                    memory_planning_algo="greedy",
                     alloc_graph_input=True,
                     alloc_graph_output=False,
                 ),
                 "foo": MemoryPlanningPass(
-                    memory_planning_algo="greedy",
                     alloc_graph_input=False,
                     alloc_graph_output=True,
                 ),
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index 12a0583ab41..ebea0acf0f4 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -17,6 +17,8 @@
 from executorch.exir.memory_planning import (
     filter_nodes,
     get_node_tensor_specs,
+    greedy,
+    naive,
     Verifier,
 )
 from executorch.exir.pass_base import PassResult
@@ -208,7 +210,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
 
 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[str, bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
     extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -222,13 +224,15 @@ def wrapper(self: "TestMemoryPlanning") -> None:
         if not criteria:
             criteria = [
                 # naive algorithm does not reuse tensor storages
-                ("naive", False),
+                (naive, False),
                 # greedy algorithm should reuse tensor storages in the testing model
-                ("greedy", True),
+                (greedy, True),
             ]
 
         for algo, expect_reuse in criteria:
-            print(f"algo {algo}, expect_reuse {expect_reuse}")
+            print(
+                f"algo {getattr(algo, '__name__', repr(algo))}, expect_reuse {expect_reuse}"
+            )
             eager_module = module_cls().eval()
             inputs = eager_module.get_random_inputs()
             graph_module = (
@@ -353,8 +357,8 @@ def verify_overlap_placeholders(
 test_return_two: Callable[..., None] = maketest(
     ModuleReturnTwo,
     criteria=[
-        ("naive", False),
-        ("greedy", True),
+        (naive, False),
+        (greedy, True),
     ],
 )
@@ -363,8 +367,8 @@ def verify_overlap_placeholders(
 test_list_arg: Callable[..., None] = maketest(
     ModuleListArg,
     criteria=[
-        ("naive", False),
-        ("greedy", True),
+        (naive, False),
+        (greedy, True),
     ],
     extra_check=ModuleListArg.extra_check,
 )
@@ -466,12 +470,12 @@ def quantize(self, eager_model: nn.Module) -> nn.Module:
     @parameterized.expand(
         [
             (
-                "naive",
+                naive,
                 [(1, 0), (3, 0), (1, 4), (3, 4), (1, 8)],
                 [0, 12, 0, 8],
             ),
             (
-                "greedy",
+                greedy,
                 [(1, 0), (3, 0), (1, 4), (3, 4), (1, 0)],
                 [0, 8, 0, 8],
             ),
@@ -479,7 +483,7 @@ def quantize(self, eager_model: nn.Module) -> nn.Module:
     )
     def test_multiple_pools(
         self,
-        algo: str,
+        algo: Callable[..., List[int]],
         expected_allocs: List[Tuple[int, int]],
         expected_bufsizes: List[int],
     ) -> None:
@@ -550,9 +554,7 @@ def count_planned_inputs(
 
         ep_no_input_planning = to_edge(export(model, inputs)).to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
@@ -572,9 +574,7 @@ def count_planned_inputs(
 
         ep_input_planning = to_edge(export(model, inputs)).to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=True
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index a167a67dd94..79578763475 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -713,7 +713,7 @@ def test_alloc_node_spec(self) -> None:
         self.assertIsNotNone(new_gm_res)
         new_gm = new_gm_res.graph_module
 
-        new_gm_res = MemoryPlanningPass("greedy")(new_gm)
+        new_gm_res = MemoryPlanningPass()(new_gm)
         self.assertIsNotNone(new_gm_res)
         new_gm = new_gm_res.graph_module
diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py
index f64a1f19981..0925a8abc89 100644
--- a/exir/tests/test_remove_view_copy.py
+++ b/exir/tests/test_remove_view_copy.py
@@ -48,9 +48,7 @@ def test_disable(self) -> None:
         etpm = to_edge(ep).to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=False,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )
@@ -72,9 +70,7 @@ def test_output_matches(self) -> None:
         etpm_remove = epm_remove.to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=True,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )
@@ -82,9 +78,7 @@ def test_output_matches(self) -> None:
         etpm_no_remove = epm_no_remove.to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=True,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )
@@ -107,9 +101,7 @@ def test_spec(self) -> None:
         etpm = to_edge(ep).to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=True,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 338d997297d..91ee2dc733b 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -389,9 +389,7 @@ def to_executorch(self) -> "LLMEdgeManager":
                     ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
diff --git a/test/end2end/exported_module.py b/test/end2end/exported_module.py
index 6e6b97b7186..2365450ae59 100644
--- a/test/end2end/exported_module.py
+++ b/test/end2end/exported_module.py
@@ -147,7 +147,7 @@ def return_wrapper():
         for method in methods:
             method_name_to_dynamic_shapes[method] = trace_dynamic_shapes
 
-        memory_planning_pass = MemoryPlanningPass("greedy")
+        memory_planning_pass = MemoryPlanningPass()
         if hasattr(eager_module, "get_memory_planning_pass"):
             memory_planning_pass = eager_module.get_memory_planning_pass()
diff --git a/test/models/export_program.py b/test/models/export_program.py
index d753475b829..caea394f33c 100644
--- a/test/models/export_program.py
+++ b/test/models/export_program.py
@@ -121,7 +121,6 @@ def get_dynamic_shapes(self):
 
     def get_memory_planning_pass(self):
         return MemoryPlanningPass(
-            memory_planning_algo="greedy",
            alloc_graph_input=False,
            alloc_graph_output=False,
        )
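Taken together, the change is a pure API migration: every call site drops the `"greedy"` string, the default stays greedy, and any downstream code still passing a string will now hand that string to `apply_algo` as if it were the algorithm itself. A before/after sketch for callers outside this patch (hedged; the exact failure mode of the old spelling is an assumption based on the `apply_algo` call shown above):

```python
from executorch.exir.memory_planning import greedy, naive
from executorch.exir.passes import MemoryPlanningPass

# Before (no longer valid -- the string would be treated as the callable):
# MemoryPlanningPass("greedy", alloc_graph_input=False)

# After: omit the argument for the greedy default, or pass a function.
MemoryPlanningPass(alloc_graph_input=False)
MemoryPlanningPass(memory_planning_algo=naive, alloc_graph_input=False)
```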