diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index ea4296cc52c..e267f5d71de 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -23,6 +23,7 @@
 import torch
 
 from executorch.devtools.etrecord import generate_etrecord
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
@@ -760,6 +761,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -774,7 +778,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 canonicalize_program(builder.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder.to_executorch(
+            passes=additional_passes,
+        )
 
         # Generate ETRecord
         if edge_manager_copy:
@@ -792,7 +798,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 canonicalize_program(builder.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder.to_executorch(passes=additional_passes)
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py
index 9a28c94f9c2..2b4d709f9b4 100644
--- a/examples/models/llama3_2_vision/runner/native.py
+++ b/examples/models/llama3_2_vision/runner/native.py
@@ -18,13 +18,15 @@
     TorchTuneLlamaRunner,
 )
 
-from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
 
 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
 
 # Note: import this after portable_lib
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa
 
@@ -43,7 +45,17 @@ def __init__(self, args):
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
         )
-        self.model = _load_for_executorch(args.pte)
+        # Save the loaded model bytes so the data does not go out of
+        # scope when the `with` block exits and get reclaimed by
+        # Python's garbage collector.
+        self.model_bytes = None
+        with open(args.pte, "rb") as f:
+            self.model_bytes = f.read()
+        # Need to use _load_for_executorch_from_buffer instead of
+        # _load_for_executorch because the latter uses MmapDataLoader,
+        # which does not implement load_into(), and load_into() is
+        # needed to load initialized mutable buffers.
+        self.model = _load_for_executorch_from_buffer(self.model_bytes)
         self.use_kv_cache = args.kv_cache
 
     def forward(
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 88247d2a274..1cd286b02fc 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder(
             warnings.warn(
                 "Mutation on a buffer in the model is detected. ExecuTorch assumes "
ExecuTorch assumes " "buffers that are mutated in the graph have a meaningless initial state, " - "only the shape and dtype will be serialized.", + "only the shape and dtype will be serialized, unless a pass which marks " + "spec.const=True such as InitializedMutableBufferPass is run.", UserWarning, stacklevel=1, ) @@ -1602,6 +1603,7 @@ def placeholder( """ spec = self.node.meta["spec"] constant_tag = self.node.meta.get("constant_tag", None) + initialize_buffer = self.node.meta.get("et_init_buffer", None) is_user_input = True if isinstance(target, str) and isinstance(spec, TensorSpec): @@ -1655,7 +1657,11 @@ def placeholder( spec.storage = real_tensor.untyped_storage() # User inputs and mutable buffers are not constants, other buffers or parameters are. - spec.const = not (is_user_input or is_mutable_buffer) + if initialize_buffer: + assert is_mutable_buffer + spec.const = True + else: + spec.const = not (is_user_input or is_mutable_buffer) evalue = ( self._tensor_spec_to_evalue(spec, constant_tag) diff --git a/exir/passes/init_mutable_pass.py b/exir/passes/init_mutable_pass.py new file mode 100644 index 00000000000..72a67b765a8 --- /dev/null +++ b/exir/passes/init_mutable_pass.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +from executorch.exir.pass_base import ExportPass + + +class InitializedMutableBufferPass(ExportPass): + """ + If the buffer has the name "cache_pos", such as in an kv_cache + module with `self.register_buffer("cache_pos", torch.arange(10))`, + mark it with a custom tag which later is used by the emitter to + flag spec.const to True, which provides the mutable buffer with + an initialized state. + """ + + def __init__(self, patterns: List[str]) -> None: + super().__init__() + self.patterns = patterns + + def placeholder(self, name: str, arg, meta): + for pattern in self.patterns: + if pattern in name: + meta["et_init_buffer"] = True + + return super().placeholder(name, arg, meta) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ebc7f02ee1a..390e2e47c07 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -25,6 +25,7 @@ from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass @@ -395,21 +396,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag return self - def to_executorch(self) -> "LLMEdgeManager": + def to_executorch( + self, passes: Optional[List[PassType]] = None + ) -> "LLMEdgeManager": """ Lower the model to executorch and get an ExecutorchProgram. """ assert self.edge_manager, "Need to run export_to_edge() first" + to_executorch_passes = [ + # If there are Linear operations left in the graph, let's execute + # them with the optimized op_linear rather than materializing a + # transpose followed by a regular op_mm. 
+            ConvertToLinearPass(),
+            QuantFusionPass(),
+        ]
+        if passes:
+            to_executorch_passes.extend(passes)
+
         self.export_program = self.edge_manager.to_executorch(
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
-                passes=[
-                    # If there are Linear operations left in the graph, let's execute
-                    # them with the optimized op_linear rather than materializing a
-                    # transpose followed by a regular op_mm.
-                    ConvertToLinearPass(),
-                    QuantFusionPass(),
-                ],
+                passes=to_executorch_passes,
                 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index 60183801b42..695e5efa72b 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -10,6 +10,7 @@
 import torch
 import torchtune.modules.attention as TorchTuneAttention
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -146,6 +147,7 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
         return self.output_proj(output)
 
@@ -322,6 +324,7 @@ class SDPA(nn.Module):
 
     def __init__(
         self,
+        max_seq_len: int,
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
         kv_cache,
     ) -> None:
         super().__init__()
+        self.max_seq_len = max_seq_len
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -348,7 +352,23 @@ def forward(
         bsz: int,
         seq_len: int,
         mask: Optional[_MaskType] = None,
+        # The args below are only used by the ET custom SDPA op.
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            start_pos,
+            None,  # Attention mask.
+            0,  # Dropout probability; ignored by the op.
+            True,  # is_causal. TODO: flip to False if the kv cache is enabled?
+        )
+        return output.view(bsz, seq_len, -1)
+
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py
new file mode 100644
index 00000000000..6029a038825
--- /dev/null
+++ b/extension/llm/modules/test/test_kv_cache.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+from typing import Callable, Tuple
+
+import torch
+from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
+
+from executorch.extension.export_util.utils import save_pte_program
+from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
+from executorch.runtime import Runtime
+from torch.testing import assert_close
+from torchtune.modules.kv_cache import KVCache
+
+
+def generate_cache_inputs(
+    seq_len: int,
+    batch_size: int = 1,
+    num_kv_heads: int = 64,
+    head_dim: int = 8,
+) -> Tuple[torch.Tensor, ...]:
+    """Helper to generate k_val and v_val for both the ET and torchtune caches."""
+    k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)
+    v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)
+
+    # For torchtune, the kv cache takes in transposed k and v.
+    k_val_trans = k_val.transpose(1, 2)
+    v_val_trans = v_val.transpose(1, 2)
+
+    return (k_val, v_val, k_val_trans, v_val_trans)
+
+
+class KVCacheTest(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 1
+        self.max_seq_len = 10
+        self.num_kv_heads = 1  # Reduced for testing purposes; usually this is 64.
+        self.head_dim = 8
+        self.dtype = torch.float
+
+        self.tt_kv_cache = KVCache(
+            batch_size=self.batch_size,
+            max_seq_len=self.max_seq_len,
+            num_kv_heads=self.num_kv_heads,
+            head_dim=self.head_dim,
+            dtype=self.dtype,
+        )
+        self.et_kv_cache = InferenceKVCache(
+            batch_size=self.batch_size,
+            max_seq_len=self.max_seq_len,
+            num_kv_heads=self.num_kv_heads,
+            head_dim=self.head_dim,
+            dtype=self.dtype,
+            transpose_cache=False,
+        )
+
+    def _test_kv_cache(self, et_cache_module: Callable):
+        """
+        Given an ExecuTorch kv cache from anywhere along the export chain,
+        compare its results against torchtune's and run basic tests.
+        """
+        # Case 1: prefill (seq_len = 3).
+        prefill_seq_len = 3
+        k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
+            prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim
+        )
+
+        et_res = et_cache_module(k_val, v_val)
+        tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
+        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))
+
+        # Check torchtune matches executorch.
+        assert_close(et_res, tt_res_transposed)
+
+        # Check the values are correct: all rows in the seq_len dim should be
+        # filled with 1s up to and including the 3rd.
+        et_k_cache = et_res[0]
+        for i in range(prefill_seq_len):
+            self.assertTrue(et_k_cache[0][i][0][0] == 1)
+        self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0)
+
+        # Case 2: token-by-token (seq_len = 1).
+        seq_len = 1
+        k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
+            seq_len, self.batch_size, self.num_kv_heads, self.head_dim
+        )
+
+        et_res = et_cache_module(k_val, v_val)
+        tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
+        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))
+
+        # Check torchtune matches executorch.
+        assert_close(tt_res_transposed, et_res)
+
+        # All rows up to and including index prefill_seq_len should now be
+        # filled with 1s.
+        et_k_cache = et_res[0]
+        for i in range(prefill_seq_len + 1):
+            self.assertTrue(et_k_cache[0][i][0][0] == 1)
+        self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0)
+
+    def export_kv_cache(
+        self,
+        kv_cache: torch.nn.Module,
+    ) -> torch.export.ExportedProgram:
+        # Wrapper since torch.export only exports forward().
+        class EtCacheWrapper(torch.nn.Module):
+            def __init__(self, kv_cache: torch.nn.Module):
+                super().__init__()
+                self.kv_cache = kv_cache
+
+            def forward(self, k_val: torch.Tensor, v_val: torch.Tensor):
+                return self.kv_cache.update(k_val, v_val)
+
+        dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len)
+        exported_kv_cache = torch.export.export(
+            EtCacheWrapper(self.et_kv_cache),
+            (
+                torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim),
+                torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim),
+            ),  # 3 as example prefill seq_len.
+            dynamic_shapes={
+                "k_val": {
+                    0: torch.export.Dim.STATIC,
+                    1: dim,
+                    2: torch.export.Dim.STATIC,
+                    3: torch.export.Dim.STATIC,
+                },
+                "v_val": {
+                    0: torch.export.Dim.STATIC,
+                    1: dim,
+                    2: torch.export.Dim.STATIC,
+                    3: torch.export.Dim.STATIC,
+                },
+            },
+        )
+        return exported_kv_cache
+
+    def test_kv_cache_eager(self):
+        self._test_kv_cache(self.et_kv_cache.update)
+
+    def test_kv_cache_export(self):
+        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
+        self._test_kv_cache(exported_kv_cache.module())
+
+    def test_kv_cache_edge(self):
+        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
+        edge_program = to_edge(
+            exported_kv_cache,
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
+            ),
+        )
+        self._test_kv_cache(edge_program._edge_programs["forward"].module())
+
+    def test_kv_cache_executorch(self):
+        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
+        edge_program = to_edge(
+            exported_kv_cache,
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
+            ),
+        )
+        et_config = ExecutorchBackendConfig(
+            passes=[InitializedMutableBufferPass(["cache_pos"])],
+        )
+        et_program = edge_program.to_executorch(config=et_config)
+
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+
+        # method.execute expects a tuple of args.
+        def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
+            return method.execute((k_val, v_val))
+
+        self._test_kv_cache(wrapped_callable)
+
+    def test_kv_cache_executorch_from_file(self):
+        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
+        edge_program = to_edge(
+            exported_kv_cache,
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
+            ),
+        )
+        et_config = ExecutorchBackendConfig(
+            passes=[InitializedMutableBufferPass(["cache_pos"])],
+        )
+        et_program = edge_program.to_executorch(config=et_config)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir)
+            with open(pte_path, "rb") as f:
+                model_bytes = f.read()
+            loaded_et_program = _load_for_executorch_from_buffer(model_bytes)
+
+            # The loaded program's forward expects a tuple of args.
+            def wrapped_callable(
+                k_val: torch.Tensor, v_val: torch.Tensor
+            ) -> torch.Tensor:
+                return loaded_et_program.forward((k_val, v_val))
+
+            self._test_kv_cache(wrapped_callable)
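
Example usage: taken together, the changes above form a small recipe. Run InitializedMutableBufferPass when lowering so the buffer's initial contents are serialized, then load the program from a buffer so load_into() can populate that state at load time. Below is a minimal sketch of that recipe, assuming this diff has been applied; the Counter module, its buffer logic, and the shapes are illustrative and not part of the diff.

import torch

from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)


class Counter(torch.nn.Module):
    """Illustrative module with an initialized mutable buffer."""

    def __init__(self):
        super().__init__()
        # The name "cache_pos" matches the pattern given to the pass below,
        # so the buffer's initial contents (an arange, not zeros) are
        # serialized along with its shape and dtype.
        self.register_buffer("cache_pos", torch.arange(10))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.cache_pos.add_(1)  # In-place update makes this a mutable buffer.
        return x + self.cache_pos


exported = torch.export.export(Counter(), (torch.zeros(10),))
et_program = to_edge(exported).to_executorch(
    config=ExecutorchBackendConfig(
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)

# Load from a buffer rather than from a file: the MmapDataLoader used by
# _load_for_executorch does not implement load_into(), which is required
# to populate initialized mutable buffers (see the runner change above).
model = _load_for_executorch_from_buffer(et_program.buffer)
print(model.forward((torch.zeros(10),)))  # First call sees cache_pos == arange(10) + 1.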