diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 0b7064c4dd6..1c1cf82d192 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -23,6 +23,7 @@ import torch from executorch.devtools.etrecord import generate_etrecord +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.export.builder import DType, LLMEdgeManager @@ -775,6 +776,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") + additional_passes = [] + if args.model in TORCHTUNE_DEFINED_MODELS: + additional_passes = [InitializedMutableBufferPass(["cache_pos"])] if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: raise ValueError("Unable to generate etrecord due to missing edge manager.") @@ -789,7 +793,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch( + passes=additional_passes, + ) # Generate ETRecord if edge_manager_copy: @@ -807,7 +813,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch(passes=additional_passes) if args.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index 9a28c94f9c2..2b4d709f9b4 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -18,13 +18,15 @@ TorchTuneLlamaRunner, ) -from executorch.extension.pybindings.portable_lib import _load_for_executorch +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) # Load custom ops and quantized ops. from executorch.extension.pybindings import portable_lib # noqa # usort: skip # Note: import this after portable_lib -from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip +from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip from executorch.kernels import quantized # noqa @@ -43,7 +45,17 @@ def __init__(self, args): use_kv_cache=args.kv_cache, vocab_size=params["vocab_size"], ) - self.model = _load_for_executorch(args.pte) + # Save the loaded model bytes to prevent data from going out of + # scope after the `with` and getting cleaned up by Python's + # garbage collector. + self.model_bytes = None + with open(args.pte, "rb") as f: + self.model_bytes = f.read() + # Need to use _load_for_executorch_from_buffer instead of + # _load_for_executorch because the latter uses MmapDataLoader, + # which doesn't have load_into() implemented, which is needed + # for loading initialized mutable buffers. + self.model = _load_for_executorch_from_buffer(self.model_bytes) self.use_kv_cache = args.kv_cache def forward( diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 88247d2a274..d08e68fa731 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder( warnings.warn( "Mutation on a buffer in the model is detected. ExecuTorch assumes " "buffers that are mutated in the graph have a meaningless initial state, " - "only the shape and dtype will be serialized.", + "only the shape and dtype will be serialized, unless a pass which sets " + 'meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.', UserWarning, stacklevel=1, ) @@ -1602,6 +1603,7 @@ def placeholder( """ spec = self.node.meta["spec"] constant_tag = self.node.meta.get("constant_tag", None) + initialize_buffer = self.node.meta.get("et_init_buffer", None) is_user_input = True if isinstance(target, str) and isinstance(spec, TensorSpec): @@ -1655,7 +1657,10 @@ def placeholder( spec.storage = real_tensor.untyped_storage() # User inputs and mutable buffers are not constants, other buffers or parameters are. - spec.const = not (is_user_input or is_mutable_buffer) + if initialize_buffer and is_mutable_buffer: + spec.const = True + else: + spec.const = not (is_user_input or is_mutable_buffer) evalue = ( self._tensor_spec_to_evalue(spec, constant_tag) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index a645fa53779..0da40859146 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -9,6 +9,7 @@ import typing import unittest from contextlib import contextmanager +from copy import deepcopy from typing import List, Optional, Tuple import executorch.exir as exir @@ -31,6 +32,7 @@ from executorch.exir.error import InternalError from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.constant_prop_pass import constant_prop_pass +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.print_program import pretty_print, print_program # noqa from executorch.exir.schema import ( @@ -56,6 +58,7 @@ from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) +from executorch.runtime import Runtime from functorch.experimental import control_flow from torch import nn @@ -243,6 +246,56 @@ def forward(self, x): ) self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null) + def test_initialized_mutable_buffer(self): + """Test that mutable buffers can hold meaningful initialized state.""" + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Mutable buffer with non-empty initial state. + self.register_buffer("cache_pos", torch.arange(0, 10)) + + def forward(self, x): + self.cache_pos.add_(1) + return self.cache_pos + + m = TestModule() + example_inputs = (torch.ones(10),) + ep = torch.export.export(m, example_inputs) + edge = to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + ), + ) + + # Save a copy of the edge program since to_executorch is + # stateful to some degree. + edge_copy = deepcopy(edge) + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program_init_pass = edge.to_executorch(config=et_config) + et_program_regular = edge_copy.to_executorch() + + runtime = Runtime.get() + program_init_pass = runtime.load_program(et_program_init_pass.buffer) + method_init_pass = program_init_pass.load_method("forward") + + program_regular = runtime.load_program(et_program_regular.buffer) + method_regular = program_regular.load_method("forward") + + # Test that the mutable buffer is initialized. + torch.allclose( + method_init_pass.execute((example_inputs))[0], torch.arange(1, 11) + ) + # Test that the mutable buffer is uninitialized and starts with default zeros, + # we test equality with torch.ones because of the mutation += 1 in the model forward. + torch.allclose( + method_regular.execute((example_inputs))[0], + torch.ones(10, dtype=torch.int64), + ) + def test_int_list_input(self): class M(torch.nn.Module): def forward(self, x, y, z): diff --git a/exir/passes/init_mutable_pass.py b/exir/passes/init_mutable_pass.py new file mode 100644 index 00000000000..e6778259d04 --- /dev/null +++ b/exir/passes/init_mutable_pass.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +from executorch.exir.pass_base import ExportPass + + +class InitializedMutableBufferPass(ExportPass): + """ + If a buffer has a name that within a specified list, set meta["et_init_buffer"] + to True, which provides the mutable buffer with an initialized state. + + As an example, a module with `self.register_buffer("cache_pos", torch.arange(10))` + when patterns = ["cache_pos"] would have its initial state set instead of being + left uninitialized by default. + """ + + def __init__(self, patterns: List[str]) -> None: + super().__init__() + self.patterns = patterns + + def placeholder(self, name: str, arg, meta): + for pattern in self.patterns: + if pattern in name: + meta["et_init_buffer"] = True + + return super().placeholder(name, arg, meta) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 7cab3c77b81..243f6f91fbb 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -27,6 +27,7 @@ from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass @@ -415,21 +416,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag return self - def to_executorch(self) -> "LLMEdgeManager": + def to_executorch( + self, passes: Optional[List[PassType]] = None + ) -> "LLMEdgeManager": """ Lower the model to executorch and get an ExecutorchProgram. """ assert self.edge_manager, "Need to run export_to_edge() first" + to_executorch_passes = [ + # If there are Linear operations left in the graph, let's execute + # them with the optimized op_linear rather than materializing a + # transpose followed by a regular op_mm. + ConvertToLinearPass(), + QuantFusionPass(), + ] + if passes: + to_executorch_passes.extend(passes) + self.export_program = self.edge_manager.to_executorch( ExecutorchBackendConfig( extract_delegate_segments=True, - passes=[ - # If there are Linear operations left in the graph, let's execute - # them with the optimized op_linear rather than materializing a - # transpose followed by a regular op_mm. - ConvertToLinearPass(), - QuantFusionPass(), - ], + passes=to_executorch_passes, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py index 82ee1febf49..6cd05b4bf65 100644 --- a/extension/llm/modules/test/test_attention.py +++ b/extension/llm/modules/test/test_attention.py @@ -11,6 +11,8 @@ import torch from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.modules.attention import ( MultiHeadAttention as ETMultiHeadAttention, ) @@ -114,7 +116,7 @@ def test_attention_eager(self): et_res = self.et_mha(self.x, self.x) # Self attention. tt_res = self.tt_mha(self.x, self.x) # Self attention. - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) self.et_mha.reset_cache() self.tt_mha.reset_cache() @@ -125,7 +127,7 @@ def test_attention_eager(self): self.x, self.x, input_pos=self.input_pos ) # Self attention with input pos. - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) # test kv cache read. Input pos can be [10, 11, ..., 19] next_input_pos = torch.arange(10, 20).unsqueeze(0) @@ -187,9 +189,8 @@ def test_attention_aoti(self): def test_attention_executorch(self): # Self attention. - # TODO: Fix kv cache - # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) - # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) with torch.no_grad(): et_mha_ep = torch.export.export( @@ -202,9 +203,15 @@ def test_attention_executorch(self): et_program = to_edge( et_mha_ep, compile_config=EdgeCompileConfig( - _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg] + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, ), - ).to_executorch() + ).to_executorch( + config=ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + ) + runtime = Runtime.get() program = runtime.load_program(et_program.buffer) method = program.load_method("forward") @@ -219,9 +226,8 @@ def test_attention_torch_cond_eager(self): self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) - # mask mask = self.causal_mask[self.input_pos, :] - # First run + # First run. et_res = self.et_mha( self.x, self.x, mask=mask, input_pos=self.input_pos ) # Self attention with input pos. @@ -229,18 +235,14 @@ def test_attention_torch_cond_eager(self): self.x, self.x, mask=mask, input_pos=self.input_pos ) # Self attention with input pos. - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) # Second run test kv cache read. Input pos is [10, 11, ..., 19] next_input_pos = torch.arange(10, 20).unsqueeze(0) empty_y = torch.full_like(self.x, torch.nan) mask = self.causal_mask[next_input_pos, :] - et_res = self.et_mha( - self.x, empty_y, mask=mask, input_pos=next_input_pos - ) # Self attention with input pos. - tt_res = self.tt_mha( - self.x, None, mask=mask, input_pos=next_input_pos - ) # Self attention with input pos. + et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos) + tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos) assert_close(et_res, tt_res) diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py new file mode 100644 index 00000000000..6029a038825 --- /dev/null +++ b/extension/llm/modules/test/test_kv_cache.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest +from typing import Callable, Tuple + +import torch +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass + +from executorch.extension.export_util.utils import save_pte_program +from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache + +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) +from executorch.runtime import Runtime +from torch.testing import assert_close +from torchtune.modules.kv_cache import KVCache + + +def generate_cache_inputs( + seq_len: int, + batch_size: int = 1, + num_kv_heads: int = 64, + head_dim: int = 8, +) -> Tuple[torch.Tensor, ...]: + """Helper to generate k_val and v_val for both et and tt caches.""" + k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) + v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) + + # For torchtune, the kv cache takes in transposed k and v. + k_val_trans = k_val.transpose(1, 2) + v_val_trans = v_val.transpose(1, 2) + + return (k_val, v_val, k_val_trans, v_val_trans) + + +class KVCacheTest(unittest.TestCase): + def setUp(self): + self.batch_size = 1 + self.max_seq_len = 10 + self.num_kv_heads = 1 # For testing purposes, usually this is 64. + self.head_dim = 8 + self.dtype = torch.float + + self.tt_kv_cache = KVCache( + batch_size=self.batch_size, + max_seq_len=self.max_seq_len, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=self.dtype, + ) + self.et_kv_cache = InferenceKVCache( + batch_size=self.batch_size, + max_seq_len=self.max_seq_len, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=self.dtype, + transpose_cache=False, + ) + + def _test_kv_cache(self, et_cache_module: Callable): + """ + Given an executorch kv cache anywhere along the export chain, compare it's results + against torchtune and run basic tests. + """ + prefill_seq_len = 3 + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( + prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim + ) + + et_res = et_cache_module(k_val, v_val) + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) + + # Check torchtune matches executorch. + assert_close(et_res, tt_res_transposed) + + # Check the values are correct, all rows in the seq_len dim should be + # filled with 1s up to and including the 3rd. + et_k_cache = et_res[0] + for i in range(prefill_seq_len): + self.assertTrue(et_k_cache[0][i][0][0] == 1) + self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0) + + """Case 2: Token-by-token (seq_len = 0)""" + seq_len = 1 + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( + seq_len, self.batch_size, self.num_kv_heads, self.head_dim + ) + + et_res = et_cache_module(k_val, v_val) + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) + + # Check torchtune matches executorch. + assert_close(tt_res_transposed, et_res) + + # All rows should be filled with 1s up to 3 + 1th row. + et_k_cache = et_res[0] + for i in range(prefill_seq_len + 1): + self.assertTrue(et_k_cache[0][i][0][0] == 1) + + self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0) + + def export_kv_cache( + self, + kv_cache: torch.nn.Module, + ) -> torch.export.ExportedProgram: + # Wrapper since torch.export only exports forward(). + class EtCacheWrapper(torch.nn.Module): + def __init__(self, kv_cache: torch.nn.Module): + super().__init__() + self.kv_cache = kv_cache + + def forward(self, k_val: torch.Tensor, v_val: torch.Tensor): + return self.kv_cache.update(k_val, v_val) + + dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len) + exported_kv_cache = torch.export.export( + EtCacheWrapper(self.et_kv_cache), + ( + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), + ), # 3 as example prefill seq_len. + dynamic_shapes={ + "k_val": { + 0: torch.export.Dim.STATIC, + 1: dim, + 2: torch.export.Dim.STATIC, + 3: torch.export.Dim.STATIC, + }, + "v_val": { + 0: torch.export.Dim.STATIC, + 1: dim, + 2: torch.export.Dim.STATIC, + 3: torch.export.Dim.STATIC, + }, + }, + ) + return exported_kv_cache + + def test_kv_cache_eager(self): + self._test_kv_cache(self.et_kv_cache.update) + + def test_kv_cache_export(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + self._test_kv_cache(exported_kv_cache.module()) + + def test_kv_cache_edge(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + self._test_kv_cache(edge_program._edge_programs["forward"].module()) + + def test_kv_cache_executorch(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program = edge_program.to_executorch(config=et_config) + + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + + # Since method.execute expects a tuple of args. + def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor: + return method.execute((k_val, v_val)) + + self._test_kv_cache(wrapped_callable) + + def test_kv_cache_executorch_from_file(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program = edge_program.to_executorch(config=et_config) + + with tempfile.TemporaryDirectory() as tempdir: + pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir) + with open(pte_path, "rb") as f: + model_bytes = f.read() + loaded_et_program = _load_for_executorch_from_buffer(model_bytes) + + # Since method.execute expects a tuple of args. + def wrapped_callable( + k_val: torch.Tensor, v_val: torch.Tensor + ) -> torch.Tensor: + return loaded_et_program.forward((k_val, v_val)) + + self._test_kv_cache(wrapped_callable)