From 45993dbf11de852731f07c63ac92092f6dbf3469 Mon Sep 17 00:00:00 2001
From: Eashan Garg
Date: Tue, 17 Sep 2024 11:38:02 -0700
Subject: [PATCH] Change memory planning API to accept a full algorithm as an
 argument rather than a string name (#4727)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4727

ExecuTorch memory planning currently accepts a string identifier naming the
desired algorithm. This makes it hard to write more customized memory
planning algorithms, since there is no way to pass them custom arguments.
This change lets users pass the memory planning function itself as an
argument instead of a string identifier.

The core changes are in:
- fbcode/executorch/exir/passes/memory_planning_pass.py
- fbcode/executorch/exir/tests/test_memory_planning.py

The remaining changes simply update all call sites in the codebase to comply
with the new API.

NOTE: A less intrusive change would be to accept either a string or a custom
function. I opted to accept only functions to keep the API simple and avoid
confusion.
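As an illustrative sketch of the new API (`verbose_greedy` is a hypothetical
name, and forwarding `*args, **kwargs` assumes `apply_algo` threads its
remaining arguments straight through to the algorithm), a custom planner can
now wrap the stock `greedy` directly:

    from typing import List

    import torch

    from executorch.exir.memory_planning import greedy
    from executorch.exir.passes import MemoryPlanningPass

    # Hypothetical custom algorithm: any Callable[..., List[int]] of this
    # shape can be handed to MemoryPlanningPass. This one delegates to the
    # stock greedy planner, then reports the planned buffer sizes.
    def verbose_greedy(
        graph_module: torch.fx.GraphModule, alignment: int, *args, **kwargs
    ) -> List[int]:
        bufsizes = greedy(graph_module, alignment, *args, **kwargs)
        print(f"greedy planned buffer sizes: {bufsizes}")
        return bufsizes

    memory_planning_pass = MemoryPlanningPass(memory_planning_algo=verbose_greedy)

Since the argument is now a plain callable, custom options can also be bound
with functools.partial instead of a wrapper function.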
Reviewed By: zonglinpeng, hsharma35, mcremon-meta

Differential Revision: D60433641

fbshipit-source-id: 0fe3677b7c3f4c3763cb1b4fe6d28ef814f2ecf9
(cherry picked from commit 618466ed56191fb21fc581c012363ed79911ea13)
---
 backends/qualcomm/tests/utils.py              |  1 -
 backends/vulkan/vulkan_preprocess.py          |  2 +-
 docs/source/compiler-memory-planning.md       |  3 +-
 .../export-to-executorch-tutorial.py          |  4 +--
 .../mediatek/model_export_scripts/llama.py    |  1 -
 examples/models/llava/export_llava.py         |  2 +-
 examples/qualcomm/oss_scripts/llama2/llama.py |  1 -
 .../qualcomm/qaihub_scripts/utils/export.py   |  1 -
 examples/qualcomm/utils.py                    |  1 -
 exir/capture/_config.py                       |  4 +--
 exir/emit/test/test_emit.py                   |  5 +--
 exir/lowered_backend_module.py                |  2 +-
 exir/memory_planning.py                       | 29 +---------------
 exir/passes/memory_planning_pass.py           | 11 +++---
 exir/program/TARGETS                          |  1 +
 exir/program/test/test_program.py             |  2 --
 exir/tests/test_memory_planning.py            | 34 +++++++++----------
 exir/tests/test_passes.py                     |  2 +-
 exir/tests/test_remove_view_copy.py           | 16 +++------
 extension/llm/export/builder.py               |  4 +--
 test/end2end/exported_module.py               |  2 +-
 test/models/export_program.py                 |  1 -
 22 files changed, 38 insertions(+), 91 deletions(-)

diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 6e4d1c7502f..7209b0a2678 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -350,7 +350,6 @@ def lower_module_and_test_output(
                 # Therefore, won't want to pre-allocate
                 # by memory manager in runtime.
                 memory_planning_pass=MemoryPlanningPass(
-                    memory_planning_algo="greedy",
                     alloc_graph_input=not self.shared_buffer,
                     alloc_graph_output=not self.shared_buffer,
                 ),
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 1865c32acd7..7e85c25faee 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -57,7 +57,7 @@ def preprocess(  # noqa: C901
             MeanToSumDiv(),
             SpecPropPass(),
             ConstraintBasedSymShapeEvalPass(),
-            MemoryPlanningPass("greedy"),
+            MemoryPlanningPass(),
         ]

         new_gm = program.graph_module
diff --git a/docs/source/compiler-memory-planning.md b/docs/source/compiler-memory-planning.md
index f7bda678e4c..0f4489654b4 100644
--- a/docs/source/compiler-memory-planning.md
+++ b/docs/source/compiler-memory-planning.md
@@ -32,7 +32,6 @@ The `MemoryPlanningPass` exposes the option to not memory plan program inputs an
 program = edge_program.to_executorch(
     exir.ExecutorchBackendConfig(
         memory_planning_pass=MemoryPlanningPass(
-            memory_planning_algo="greedy",
             alloc_graph_input=False, # Inputs will not be memory planned, the data_ptr for input tensors after model load will be nullptr
             alloc_graph_output=True, # Outputs will be memory planned, the data_ptr for output tensors after model load will be in the `planned_memory`.
         )
@@ -77,7 +76,7 @@ Then later when lowering to ExecuTorch you can use your custom plan in the follo
 program = edge_program.to_executorch(
     exir.ExecutorchBackendConfig(
         memory_planning_pass=CustomPoolMemoryPlanningPass(
-            memory_planning_algo="greedy",
+            memory_planning_algo=greedy,
         )
     )
 )
diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py
index 57650abda41..fac3eab08e5 100644
--- a/docs/source/tutorials_source/export-to-executorch-tutorial.py
+++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py
@@ -523,9 +523,7 @@ def forward(self, a, x, b):
 executorch_program: ExecutorchProgramManager = edge_program.to_executorch(
     ExecutorchBackendConfig(
         passes=[],  # User-defined passes
-        memory_planning_pass=MemoryPlanningPass(
-            "greedy"
-        ),  # Default memory planning pass
+        memory_planning_pass=MemoryPlanningPass(),  # Default memory planning pass
     )
 )
diff --git a/examples/mediatek/model_export_scripts/llama.py b/examples/mediatek/model_export_scripts/llama.py
index 980a502c5ae..b2fef26a4cf 100644
--- a/examples/mediatek/model_export_scripts/llama.py
+++ b/examples/mediatek/model_export_scripts/llama.py
@@ -365,7 +365,6 @@ def export_to_et_ir(
     executorch_program = delegated_program.to_executorch(
         config=exir.ExecutorchBackendConfig(
             memory_planning_pass=exir.passes.MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=False,
                 alloc_graph_output=False,
             ),
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index b1e9d64ee9f..47a5407cf18 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -233,7 +233,7 @@ def export_all(llava_model: LlavaModel):
             passes=[
                 QuantFusionPass(),
             ],
-            memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False),
+            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             sym_shape_eval_pass={
                 "image_encoder": ConstraintBasedSymShapeEvalPass(),
                 "text_model": ConstraintBasedSymShapeEvalPass(),
diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py
index 9712b1c08fd..d74cfa0ef07 100644
--- a/examples/qualcomm/oss_scripts/llama2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama2/llama.py
@@ -311,7 +311,6 @@ def lowering_modules(
             # Therefore, won't want to pre-allocate
             # by memory manager in runtime.
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=False,
                 alloc_graph_output=False,
             ),
diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py
index 9dfe6796491..b742f59f1d4 100644
--- a/examples/qualcomm/qaihub_scripts/utils/export.py
+++ b/examples/qualcomm/qaihub_scripts/utils/export.py
@@ -220,7 +220,6 @@ def compile(args):
     )
     # setup memory planning
     memory_planning_pass = MemoryPlanningPass(
-        memory_planning_algo="greedy",
         alloc_graph_input=args.allocate_graph_io,
         alloc_graph_output=args.allocate_graph_io,
     )
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 088ffe7ea14..9c4cd4453f0 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -285,7 +285,6 @@ def build_executorch_binary(
             # Therefore, won't want to pre-allocate
             # by memory manager in runtime.
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=not shared_buffer,
                 alloc_graph_output=not shared_buffer,
             ),
diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index 11a0d6d069d..24865e7a841 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -56,9 +56,7 @@ class ExecutorchBackendConfig:

     # A single memory planning pass can be defined for all the programs in the
     # EdgeProgramManager or can be defined per program.
-    memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass(
-        "greedy"
-    )
+    memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass()
     to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False)
     dynamic_memory_planning_mode: DynamicMemoryPlanningMode = (
         DynamicMemoryPlanningMode.UPPER_BOUND
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 123896ecdba..4d362d1b516 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1145,7 +1145,6 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
         config = exir.ExecutorchBackendConfig(
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 # allow_lifetime_and_storage_overlap: bool = False,
                 alloc_graph_input=True,
                 alloc_graph_output=False,
@@ -1606,9 +1605,7 @@ def forward(self, x):
         )
         model = model.to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py
index e50d3038dac..bc42bba9a26 100644
--- a/exir/lowered_backend_module.py
+++ b/exir/lowered_backend_module.py
@@ -326,7 +326,7 @@ def program(
         verifiers=[lowered_exported_program.verifier],
     )
     if memory_planning is None:
-        memory_planning = MemoryPlanningPass("greedy")
+        memory_planning = MemoryPlanningPass()
     exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
     emitted_program = emit_program(
         exported_program, emit_stacktrace=emit_stacktrace
diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 859bd069013..3c28639ba13 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -18,12 +18,7 @@
 from executorch.exir import memory
 from executorch.exir.control_flow import while_loop as exir_while
 from executorch.exir.delegate import executorch_call_delegate
-from executorch.exir.error import (
-    ExportError,
-    ExportErrorType,
-    internal_assert,
-    InternalError,
-)
+from executorch.exir.error import internal_assert, InternalError
 from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
 from executorch.exir.schema import TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
@@ -255,17 +250,6 @@ def verify_graph_input_output(self) -> None:
         ), f"Misallocate graph output {graph_output_allocated} v.s. {self.alloc_graph_output}"


-def register_algo(fn: Callable[..., List[int]]) -> Callable[..., List[int]]:
-    algo_name = fn.__name__
-    if algo_name in REGISTERED_ALGOS:
-        raise ExportError(
-            ExportErrorType.VIOLATION_OF_SPEC,
-            f"Re-registering memory planning algorithm {algo_name}",
-        )
-    REGISTERED_ALGOS[algo_name] = fn
-    return fn
-
-
 def _is_out_var_node(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
@@ -561,7 +545,6 @@ def get_node_tensor_specs(
     ]


-@register_algo
 def greedy(
     graph_module: torch.fx.GraphModule,
     alignment: int,
@@ -615,7 +598,6 @@ def greedy(
     return total_sizes


-@register_algo
 def naive(
     graph_module: torch.fx.GraphModule,
     alignment: int,
@@ -656,15 +638,6 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     return bufsizes


-def get_algo(algo_name: str) -> Callable[..., List[int]]:
-    if algo_name not in REGISTERED_ALGOS:
-        raise ExportError(
-            ExportErrorType.NOT_SUPPORTED,
-            f"Memory planning algorithm '{algo_name}' not found",
-        )
-    return REGISTERED_ALGOS[algo_name]
-
-
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
     for nd in graph_module.graph.nodes:
         if nd.target is torch.ops.higher_order.cond:
diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py
index 9295cabcab6..112b8f5fc52 100644
--- a/exir/passes/memory_planning_pass.py
+++ b/exir/passes/memory_planning_pass.py
@@ -6,7 +6,7 @@

 import logging
 import warnings
-from typing import Optional
+from typing import Callable, List, Optional

 import torch
 from executorch.exir.error import internal_assert
@@ -14,8 +14,8 @@
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
-    get_algo,
     get_node_tensor_specs,
+    greedy,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -27,7 +27,7 @@ class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: str = "greedy",
+        memory_planning_algo: Callable[..., List[int]] = greedy,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
@@ -96,14 +96,13 @@ def run(
             memory_planning_algo
         """
         self._set_alloc_node_spec(graph_module)
-        algo = get_algo(self.memory_planning_algo)
         # TODO(shunting) if people have concern of adding a field to GraphModule
         # directly, we should define a GraphModule subclass that we can add our
         # customized fields. Using the graph_module object to convey information across
         # passes/stages is quite natural and avoid yet another 'context' data structure
         # to do the job.
         _ = apply_algo(
-            algo,
+            self.memory_planning_algo,
             graph_module,
             self.alignment,
             graph_signature,
@@ -125,7 +124,7 @@
                 self.allow_lifetime_and_storage_overlap
             )
             logging.debug(
-                f"The {self.memory_planning_algo} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
+                f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
             )
             verifier.verify_graph_input_output()
         return PassResult(graph_module, True)
diff --git a/exir/program/TARGETS b/exir/program/TARGETS
index 730c9e93aed..fc73abf1ff7 100644
--- a/exir/program/TARGETS
+++ b/exir/program/TARGETS
@@ -22,6 +22,7 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir:error",
         "//executorch/exir:graph_module",
+        "//executorch/exir:pass_base",
         "//executorch/exir:pass_manager",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py
index edad4b24f1c..73eea7b93ef 100644
--- a/exir/program/test/test_program.py
+++ b/exir/program/test/test_program.py
@@ -250,12 +250,10 @@ def test_executorch_manager_multi_config(self):
         def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
             return {
                 "forward": MemoryPlanningPass(
-                    memory_planning_algo="greedy",
                     alloc_graph_input=True,
                     alloc_graph_output=False,
                 ),
                 "foo": MemoryPlanningPass(
-                    memory_planning_algo="greedy",
                     alloc_graph_input=False,
                     alloc_graph_output=True,
                 ),
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index 12a0583ab41..ebea0acf0f4 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -17,6 +17,8 @@
 from executorch.exir.memory_planning import (
     filter_nodes,
     get_node_tensor_specs,
+    greedy,
+    naive,
     Verifier,
 )
 from executorch.exir.pass_base import PassResult
@@ -208,7 +210,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:

 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[str, bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
     extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -222,13 +224,15 @@ def wrapper(self: "TestMemoryPlanning") -> None:
         if not criteria:
             criteria = [
                 # naive algorithm does not reuse tensor storages
-                ("naive", False),
+                (naive, False),
                 # greedy algorithm should reuse tensor storages in the testing model
-                ("greedy", True),
+                (greedy, True),
             ]

         for algo, expect_reuse in criteria:
-            print(f"algo {algo}, expect_reuse {expect_reuse}")
+            print(
+                f"algo {getattr(algo, '__name__', repr(algo))}, expect_reuse {expect_reuse}"
+            )
             eager_module = module_cls().eval()
             inputs = eager_module.get_random_inputs()
             graph_module = (
@@ -353,8 +357,8 @@ def verify_overlap_placeholders(
     test_return_two: Callable[..., None] = maketest(
         ModuleReturnTwo,
         criteria=[
-            ("naive", False),
-            ("greedy", True),
+            (naive, False),
+            (greedy, True),
         ],
     )

@@ -363,8 +367,8 @@
     test_list_arg: Callable[..., None] = maketest(
         ModuleListArg,
         criteria=[
-            ("naive", False),
-            ("greedy", True),
+            (naive, False),
+            (greedy, True),
         ],
         extra_check=ModuleListArg.extra_check,
     )
@@ -466,12 +470,12 @@ def quantize(self, eager_model: nn.Module) -> nn.Module:
     @parameterized.expand(
         [
             (
-                "naive",
+                naive,
                 [(1, 0), (3, 0), (1, 4), (3, 4), (1, 8)],
                 [0, 12, 0, 8],
             ),
             (
-                "greedy",
+                greedy,
                 [(1, 0), (3, 0), (1, 4), (3, 4), (1, 0)],
                 [0, 8, 0, 8],
             ),
@@ -479,7 +483,7 @@ def quantize(self, eager_model: nn.Module) -> nn.Module:
     )
     def test_multiple_pools(
         self,
-        algo: str,
+        algo: Callable[..., List[int]],
         expected_allocs: List[Tuple[int, int]],
         expected_bufsizes: List[int],
     ) -> None:
@@ -550,9 +554,7 @@ def count_planned_inputs(

         ep_no_input_planning = to_edge(export(model, inputs)).to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
@@ -572,9 +574,7 @@ def count_planned_inputs(

         ep_input_planning = to_edge(export(model, inputs)).to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=True
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index a167a67dd94..79578763475 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -713,7 +713,7 @@ def test_alloc_node_spec(self) -> None:
         self.assertIsNotNone(new_gm_res)
         new_gm = new_gm_res.graph_module

-        new_gm_res = MemoryPlanningPass("greedy")(new_gm)
+        new_gm_res = MemoryPlanningPass()(new_gm)
         self.assertIsNotNone(new_gm_res)
         new_gm = new_gm_res.graph_module
diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py
index f64a1f19981..0925a8abc89 100644
--- a/exir/tests/test_remove_view_copy.py
+++ b/exir/tests/test_remove_view_copy.py
@@ -48,9 +48,7 @@ def test_disable(self) -> None:
         etpm = to_edge(ep).to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=False,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )

@@ -72,9 +70,7 @@ def test_output_matches(self) -> None:
         etpm_remove = epm_remove.to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=True,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )

@@ -82,9 +78,7 @@ def test_output_matches(self) -> None:
         etpm_no_remove = epm_no_remove.to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=True,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )

@@ -107,9 +101,7 @@ def test_spec(self) -> None:
         etpm = to_edge(ep).to_executorch(
             config=ExecutorchBackendConfig(
                 remove_view_copy=True,
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             ),
         )
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 338d997297d..91ee2dc733b 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -389,9 +389,7 @@ def to_executorch(self) -> "LLMEdgeManager":
                     ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
diff --git a/test/end2end/exported_module.py b/test/end2end/exported_module.py
index 6e6b97b7186..2365450ae59 100644
--- a/test/end2end/exported_module.py
+++ b/test/end2end/exported_module.py
@@ -147,7 +147,7 @@ def return_wrapper():
             for method in methods:
                 method_name_to_dynamic_shapes[method] = trace_dynamic_shapes

-        memory_planning_pass = MemoryPlanningPass("greedy")
+        memory_planning_pass = MemoryPlanningPass()
         if hasattr(eager_module, "get_memory_planning_pass"):
             memory_planning_pass = eager_module.get_memory_planning_pass()
diff --git a/test/models/export_program.py b/test/models/export_program.py
index d753475b829..caea394f33c 100644
--- a/test/models/export_program.py
+++ b/test/models/export_program.py
@@ -121,7 +121,6 @@ def get_dynamic_shapes(self):

     def get_memory_planning_pass(self):
         return MemoryPlanningPass(
-            memory_planning_algo="greedy",
             alloc_graph_input=False,
             alloc_graph_output=False,
         )
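
Example migration for a typical call site (an illustrative recap of the
call-site changes above, not part of the diff itself):

    from executorch.exir.memory_planning import greedy
    from executorch.exir.passes import MemoryPlanningPass

    # Before: the algorithm was selected by string name.
    #   MemoryPlanningPass("greedy", alloc_graph_input=False)

    # After: pass the algorithm itself. The two lines below are equivalent,
    # since greedy is the default.
    MemoryPlanningPass(memory_planning_algo=greedy, alloc_graph_input=False)
    MemoryPlanningPass(alloc_graph_input=False)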