diff --git a/extension/llm/modules/README.md b/extension/llm/modules/README.md
index 3694f8b1551..e6e1a20cecf 100644
--- a/extension/llm/modules/README.md
+++ b/extension/llm/modules/README.md
@@ -1,14 +1,17 @@
-## Export Friendly Modules
+## Export-friendly Modules
 
-Modules in this directory are:
-* Extending `torch.nn.Module`.
-* Guranteed to work out of the box with `torch.export.export()` and `torch.aot_compile()`.
-* Guranteed to be able to work with ExecuTorch.
+Modules in this directory:
+* Extend `torch.nn.Module`.
+* Are guaranteed to work out of the box with `torch.export.export()`.
+* Should work out of the box with `torch.aot_compile()`.
+* Should work with ExecuTorch.
 
 All modules should be covered by unit tests to make sure they are:
-1. giving the same output as the reference implementation in PyTorch or torchtune
-2. export friendly
-3. AOTI friendly
-4. ExecuTorch friendly
+1. Giving the same output as the reference eager model in PyTorch or TorchTune
+2. Export-friendly
 
-Notice that these modules are subject to change (may upstream to torchtune) so proceed with caution.
+Additionally, we aim to make these modules:
+3. AOTI-friendly
+4. ExecuTorch-friendly
+
+These modules are subject to change (may upstream to TorchTune) so proceed with caution.
diff --git a/extension/llm/modules/mha.py b/extension/llm/modules/mha.py
new file mode 100644
index 00000000000..0bfa4eb20ce
--- /dev/null
+++ b/extension/llm/modules/mha.py
@@ -0,0 +1,404 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Optional
+
+import torch
+import torchtune.modules.attention as TorchTuneAttention
+from torch import nn
+from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
+from torchtune.modules.kv_cache import KVCache
+
+logger = logging.getLogger(__name__)
+
+
+class MultiHeadAttention(nn.Module):
+    """
+    NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except
+    that SDPA is factored out so that it can be swapped for more
+    efficient ExecuTorch-defined SDPA ops.
+
+    Multi-headed attention layer with support for grouped query
+    attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.
+
+    GQA is a version of multiheaded attention (MHA) which uses fewer
+    key/value heads than query heads by grouping n query heads for each
+    key and value head. Multi-Query Attention is an extreme
+    version where we have a single key and value head shared by all
+    query heads.
+
+    Following is an example of MHA, GQA and MQA with num_heads = 4
+
+    (credit for the documentation:
+    `litgpt.Config `_).
+
+
+    ::
+
+        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
+        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
+        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
+        │    │    │    │         │        │                 │
+        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
+        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
+        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
+        │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
+        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
+        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
+        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
+        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
+                MHA                    GQA                   MQA
+           n_kv_heads=4           n_kv_heads=2           n_kv_heads=1
+
+    Args:
+        embed_dim (int): embedding dimension for the model
+        num_heads (int): number of query heads. 
For MHA this is also the
+            number of heads for key and value
+        num_kv_heads (int): number of key and value heads. User should ensure
+            ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
+            for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
+        head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
+        q_proj (nn.Module): projection layer for query.
+        k_proj (nn.Module): projection layer for key.
+        v_proj (nn.Module): projection layer for value.
+        output_proj (nn.Module): projection layer for output.
+        pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
+        q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied
+            before updating from kv_cache. This means it will only support token wide normalization and not
+            batch or sequence wide normalization.
+        k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
+        kv_cache (Optional[KVCache]): KVCache object used to cache key and value
+        max_seq_len (int): maximum sequence length supported by the model.
+            This is needed to compute the RoPE Cache. Default: 4096.
+        is_causal (bool): sets the default mask to causal when no mask is provided
+        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
+            Default value is 0.0.
+
+    Raises:
+        ValueError: If ``num_heads % num_kv_heads != 0``
+        ValueError: If ``embed_dim % num_heads != 0``
+        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
+        ValueError: if q_norm is defined without k_norm or vice versa
+    """
+
+    def __init__(
+        self,
+        *,
+        embed_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        q_proj: nn.Module,
+        k_proj: nn.Module,
+        v_proj: nn.Module,
+        output_proj: nn.Module,
+        pos_embeddings: Optional[nn.Module] = None,
+        q_norm: Optional[nn.Module] = None,
+        k_norm: Optional[nn.Module] = None,
+        kv_cache: Optional[KVCache] = None,
+        max_seq_len: int = 4096,
+        is_causal: bool = True,
+        attn_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        if num_heads % num_kv_heads != 0:
+            raise ValueError(
+                f"num_heads ({num_heads}) must be divisible by "
+                f"num_kv_heads ({num_kv_heads})"
+            )
+
+        if embed_dim % num_heads != 0:
+            raise ValueError(
+                f"embed_dim ({embed_dim}) must be divisible by "
+                f"num_heads ({num_heads})"
+            )
+
+        if attn_dropout < 0 or attn_dropout > 1:
+            raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0")
+
+        if bool(q_norm) ^ bool(k_norm):
+            raise ValueError("q and k norm must be set together")
+
+        # Set attributes
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.embed_dim = embed_dim
+        self.attn_dropout = attn_dropout
+        self.head_dim = head_dim
+        self.max_seq_len = max_seq_len
+        self.is_causal = is_causal
+
+        # Set layers
+        self.kv_cache = kv_cache
+        self.q_proj = q_proj
+        self.k_proj = k_proj
+        self.v_proj = v_proj
+        self.output_proj = output_proj
+        self.q_norm = q_norm
+        self.k_norm = k_norm
+        self.pos_embeddings = pos_embeddings
+
+        # Use flex attention if supported and we are sample packing
+        self._attention_call = _sdpa_or_flex_attention()
+        self._sdpa = SDPA(
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            q_per_kv=self.num_heads // self.num_kv_heads,
+            attn_dropout=self.attn_dropout if self.training else 0.0,
+            is_causal=self.is_causal,
+            attention_fn=self._attention_call,
+            kv_cache=self.kv_cache,
+        )
+
+        # This flag indicates 
whether to update the kv-cache during forward
+        # passes. When disabled, we can have the cache set up but still
+        # perform normal forward passes.
+        self.cache_enabled = False
+
+    def setup_cache(
+        self, batch_size: int, dtype: torch.dtype, max_seq_len: int
+    ) -> None:
+        """Set up key-value caches for attention calculation. If called
+        after kv_cache is already set up, this will be skipped.
+
+        Args:
+            batch_size (int): batch size for the caches.
+            dtype (torch.dtype): dtype for the caches.
+            max_seq_len (int): maximum sequence length model will be run with.
+        """
+        # Don't overwrite user defined kv_cache from init
+        if self.kv_cache is not None:
+            logger.warning(
+                "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
+            )
+        else:
+            self.kv_cache = KVCache(
+                batch_size=batch_size,
+                max_seq_len=max_seq_len,
+                num_heads=self.num_heads,
+                head_dim=self.head_dim,
+                dtype=dtype,
+            )
+            self._sdpa.kv_cache = self.kv_cache
+            self.cache_enabled = True
+
+    def reset_cache(self):
+        """Reset the key value caches."""
+        if self.kv_cache is None:
+            raise RuntimeError(
+                "Key value caches are not setup. Call ``setup_caches()`` first."
+            )
+        self.kv_cache.reset()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        *,
+        mask: Optional[_MaskType] = None,
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
+            y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input
+                for k and v. For self attention, x=y. Optional only with kv_cache enabled.
+            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
+                and before the softmax. Either:
+
+                A boolean tensor with shape ``[b x s x s]``, or ``[b x s x self.encoder_max_cache_seq_len]``
+                if using KV-caching with encoder/decoder layers.
+                A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
+                token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
+                is used by default.
+
+                A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
+                created via `create_block_mask `_. We use
+                :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
+                Default is None.
+            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
+                of each token. During training, this is used to indicate the positions
+                of each token relative to its sample when packed, shape [b x s].
+                During inference, this indicates the position of the current token.
+                If None, assume the index of the token is its position id. Default is None.
+
+        Raises:
+            ValueError: If no ``y`` input is given and ``kv_cache`` is not enabled. 
+
+        Returns:
+            torch.Tensor: output tensor with attention applied
+
+        Notation used for tensor shapes:
+            - b: batch size
+            - s_x: sequence length for x
+            - s_y: sequence length for y
+            - n_h: num heads
+            - n_kv: num kv heads
+            - d: embed dim
+            - h_d: head dim
+        """
+        # x has shape [b, s_x, d]
+        # y has shape [b, s_y, d]
+        b, s_x, _ = x.shape
+        s_y = y.shape[1] if y is not None else 0
+
+        # q has shape [b, s_x, num_heads * head_dim]
+        q = self.q_proj(x)
+
+        # number of queries per key/value
+        q_per_kv = self.num_heads // self.num_kv_heads
+        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+
+        # Apply positional embeddings
+        if self.pos_embeddings is not None:
+            q = self.pos_embeddings(q, input_pos=input_pos)
+
+        # Normalize q
+        if self.q_norm is not None:
+            q = self.q_norm(q)
+
+        if y is None:
+            if self.kv_cache is None:
+                raise ValueError(
+                    "Must provide y input or use kv_cache to enable streaming decoding"
+                )
+            k = self.kv_cache.k_cache
+            v = self.kv_cache.v_cache
+        else:
+            # Update k and v shape, positional embeddings, and normalization
+
+            # k has shape [b, s_y, num_kv_heads * head_dim]
+            # v has shape [b, s_y, num_kv_heads * head_dim]
+            k = self.k_proj(y)
+            v = self.v_proj(y)
+
+            # Apply positional embeddings
+            # k: [b, s_y, n_kv, h_d]
+            k = k.view(b, s_y, -1, self.head_dim)
+            v = v.view(b, s_y, -1, self.head_dim)
+            if self.pos_embeddings is not None:
+                k = self.pos_embeddings(k, input_pos=input_pos)
+
+            # Normalize k
+            if self.k_norm is not None:
+                k = self.k_norm(k)
+
+            # Update key-value cache
+            if self.kv_cache is not None and self.cache_enabled:
+                k, v = self.kv_cache.update(k, v)
+
+        output = self._sdpa(q, k, v, b, s_x)
+        return self.output_proj(output)
+
+
+class SDPA(nn.Module):
+    """
+    TorchTune's SDPA, factored out so that it can be optimized and swapped
+    out for a more efficient implementation.
+    """
+
+    def __init__(
+        self,
+        num_kv_heads: int,
+        num_heads: int,
+        head_dim: int,
+        q_per_kv: int,
+        attn_dropout: float,
+        is_causal: bool,
+        attention_fn,
+        kv_cache,
+    ) -> None:
+        super().__init__()
+        self.num_kv_heads = num_kv_heads
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.q_per_kv = q_per_kv
+        self.attn_dropout = attn_dropout
+        self.is_causal = is_causal
+        self._attention_fn = attention_fn
+        self.kv_cache = kv_cache
+
+    def forward(
+        self,
+        q: torch.Tensor,  # [b, s, n_h, h_d]
+        k: torch.Tensor,  # [b, s, n_kv, h_d]
+        v: torch.Tensor,  # [b, s, n_kv, h_d]
+        bsz: int,
+        seq_len: int,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # to match q.
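+        # Example (assuming num_heads=32 and num_kv_heads=8, i.e. q_per_kv=4):
+        # k of shape [b, s, 8, h_d] is viewed as [b, s, 8, 1, h_d], expanded to
+        # [b, s, 8, 4, h_d], and reshaped to [b, s, 32, h_d] so that k and v
+        # line up with q head-for-head before the attention call.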
+ + # k: [bsz, seq_len, n_kv, 1, h_d] + # v: [bsz, seq_len, n_kv, 1, h_d] + k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + + # Expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + + # [bsz, s, n_h, h_d] + k = k.reshape(bsz, seq_len, -1, self.head_dim) + v = v.reshape(bsz, seq_len, -1, self.head_dim) + + # [bsz, n_h, s, h_d] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output = self._attention_fn( + q, + k, + v, + mask=mask, + dropout_p=self.attn_dropout, + is_causal=self.kv_cache is None and mask is None and self.is_causal, + ) + # Reshape the output to be the same shape as the input + return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + + +def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None: + for name, child in module.named_children(): + if isinstance(child, TorchTuneAttention.MultiHeadAttention): + setattr( + module, + name, + MultiHeadAttention( + embed_dim=child.embed_dim, + num_heads=child.num_heads, + num_kv_heads=child.num_kv_heads, + head_dim=child.head_dim, + q_proj=child.q_proj, + k_proj=child.k_proj, + v_proj=child.v_proj, + output_proj=child.output_proj, + pos_embeddings=child.pos_embeddings, + q_norm=child.q_norm, + k_norm=child.k_norm, + kv_cache=child.kv_cache, + max_seq_len=child.max_seq_len, + is_causal=child.is_causal, + attn_dropout=child.attn_dropout, + ), + ) + else: + replace_mha_with_inference_mha(child) + + +def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module: + """ + Replace TorchTune's MHA with an inference friendly version of MHA that + separates out the inference-related parts for further optimization. + """ + _replace_mha_with_inference_mha(module) + return module diff --git a/extension/llm/modules/test/test_mha.py b/extension/llm/modules/test/test_mha.py new file mode 100644 index 00000000000..0dc7cba6858 --- /dev/null +++ b/extension/llm/modules/test/test_mha.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.exir import EdgeCompileConfig, to_edge + +from executorch.extension.llm.modules.mha import ( + MultiHeadAttention as ETMultiHeadAttention, +) +from executorch.runtime import Runtime +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention + + +torch.manual_seed(0) + + +class AttentionTest(unittest.TestCase): + def setUp(self): + super().setUp() + + # Constants + self.embed_dim = 2048 + self.num_heads = 32 + self.num_kv_heads = 8 + self.head_dim = 64 + self.max_seq_len = 128 + self.rope_base = 500_000 + self.scale_factor = 32 + + # Module dependency injections. 
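+        # The same projection and RoPE modules are injected into both the
+        # TorchTune reference MHA and the ExecuTorch MHA below, so the two
+        # share weights and their outputs can be compared directly.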
+ self.q_proj = torch.nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = torch.nn.Linear( + self.embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.v_proj = torch.nn.Linear( + self.embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.pos_embeddings = Llama3ScaledRoPE( + dim=self.head_dim, + max_seq_len=self.max_seq_len, + base=self.rope_base, + scale_factor=self.scale_factor, + ) + + # Original TorchTune reference module to test accuracy against. + self.tt_mha = TTMultiHeadAttention( + embed_dim=self.embed_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=self.pos_embeddings, + max_seq_len=self.max_seq_len, + ) + + # Source transformed module that we are testing. + self.et_mha = ETMultiHeadAttention( + embed_dim=self.embed_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=self.pos_embeddings, + max_seq_len=self.max_seq_len, + ) + + # Common inputs. + seq_len = 10 + self.x = torch.randn(1, seq_len, self.embed_dim) + seq_len_dim = torch.export.Dim("seq_len", min=1, max=100) + self.dynamic_shapes = ( + {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, + {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, + ) + + def test_attention_eager(self): + et_res = self.et_mha(self.x, self.x) # Self attention. + tt_res = self.tt_mha(self.x, self.x) # Self attention. + + self.assertTrue(torch.allclose(et_res, tt_res)) + + # TODO: KV cache. + # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20) + # self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20) + + # et_res = self.et_mha(self.x, self.x) # Self attention. + # tt_res = self.tt_mha(self.x, self.x) # Self attention. + + # self.assertTrue(torch.allclose(et_res, tt_res)) + + def test_attention_export(self): + # Self attention. + et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.x), + kwargs=None, + dynamic_shapes=self.dynamic_shapes, + ) + et_res = et_mha_ep.module()(self.x, self.x) + tt_res = self.tt_mha(self.x, self.x) + self.assertTrue(torch.allclose(et_res, tt_res)) + + # TODO: KV cache. + + def test_attention_aoti(self): + # TODO. + pass + + def test_attention_executorch(self): + # Self attention. + et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.x), + kwargs=None, + dynamic_shapes=self.dynamic_shapes, + ) + et_program = to_edge( + et_mha_ep, + compile_config=EdgeCompileConfig(), + ).to_executorch() + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + et_res = method.execute((self.x, self.x)) + tt_res = self.tt_mha(self.x, self.x) + + self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06)) + + # TODO: KV cache.
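Taken together, the pieces above are meant to be used roughly as follows. This is a minimal sketch that mirrors `setUp` and `test_attention_executorch`; the small `Block` wrapper and the concrete tensor sizes are illustrative assumptions, not APIs introduced by this diff.

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.extension.llm.modules.mha import replace_mha_with_inference_mha
from executorch.runtime import Runtime
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention

embed_dim, num_heads, num_kv_heads, head_dim, max_seq_len = 2048, 32, 8, 64, 128

# A TorchTune MHA configured the same way as in setUp() above (GQA with
# 32 query heads and 8 key/value heads).
tt_mha = TTMultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    q_proj=torch.nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=torch.nn.Linear(embed_dim, embed_dim, bias=False),
    pos_embeddings=Llama3ScaledRoPE(
        dim=head_dim, max_seq_len=max_seq_len, base=500_000, scale_factor=32
    ),
    max_seq_len=max_seq_len,
)


class Block(torch.nn.Module):
    """Illustrative wrapper so the source transform has a parent module to walk."""

    def __init__(self, attn):
        super().__init__()
        self.attn = attn

    def forward(self, x):
        return self.attn(x, x)  # self-attention: x is used for both q and k/v


model = Block(tt_mha)
# Swap TorchTune's MHA for the export-friendly version defined in mha.py.
replace_mha_with_inference_mha(model)

x = torch.randn(1, 10, embed_dim)
ep = torch.export.export(model, (x,))

# Lower to ExecuTorch and run the result, as in test_attention_executorch.
et_program = to_edge(ep, compile_config=EdgeCompileConfig()).to_executorch()
method = Runtime.get().load_program(et_program.buffer).load_method("forward")
out = method.execute((x,))[0]  # [1, 10, embed_dim]
```

KV-cache setup (`setup_cache`) is left out of the sketch since the tests above still mark that path as TODO.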