From 7e997a441777727247757170e2a6f4e3fa7d419f Mon Sep 17 00:00:00 2001
From: lucylq
Date: Mon, 5 May 2025 14:01:54 -0700
Subject: [PATCH] Refactor attention v2

Pull Request resolved: https://github.com/pytorch/executorch/pull/10623

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the
layers into Transformer. The motivation is to allow customizing the linear
layers inside attention for LoRA (e.g. making wq a LoraLinear instead of a
regular linear). In the next diff (D73517350), we pull wq, wk, wv, and wo out
of the attention and pass those in as well. This allows us to customize
attention parameters without passing in ModelArgs and doing the customization
deep inside attention.py.

This modularizes our attention/transformer components, though it also means
that users have to do a bit more work to construct the attention layers and
pass them to the Transformer. It follows the torchtune structure more closely,
e.g.
https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 282118266
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
---
 examples/models/llama/llama_transformer.py    | 73 +++++++++++++++----
 examples/models/llama/model.py                |  5 +-
 .../tests/test_pre_quantization_transforms.py |  7 +-
 .../llama/tests/test_static_attention.py      |  6 +-
 examples/models/llava/model.py                |  4 +-
 5 files changed, 73 insertions(+), 22 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 5c8db7f208d..1fdcdcd91fc 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,19 +84,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -103,6 +106,24 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+ """ + if args.attention_type not in ATTENTION_REGISTRY: + raise ValueError( + f"Unknown attention type: {args.attention_type}. " + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + cls = ATTENTION_REGISTRY[args.attention_type] + attention = cls(args, layer_id, rope) + return TransformerBlock(args, attention) + def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN h, attn_options_update = self.attention.forward( self.attention_norm(x), freqs_cos, freqs_sin, **attn_options @@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: class Transformer(nn.Module): - def __init__(self, params: ModelArgs): + def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope): + """ + Transformer model. + Args: + params (ModelArgs): model configuration parameters. + layers (nn.ModuleList): list of transformer blocks - see the + `TransformerBlock` type above. + rope (Rope): the rope object to use for rotary embeddings. + """ super().__init__() self.params = params self.vocab_size = params.vocab_size @@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs): if self.apply_embedding else None ) - self.rope = Rope(params) - self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params, self.rope)) + self.layers = layers + self.rope = rope self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = ( nn.Linear(params.dim, params.vocab_size, bias=False) @@ -212,3 +239,23 @@ def forward( return logits, attn_options_update return logits + + +def construct_transformer(model_args: ModelArgs) -> Transformer: + """ + Construct a Transformer model from the given model arguments. + """ + rope = Rope(model_args) + if model_args.attention_type not in ATTENTION_REGISTRY: + raise ValueError( + f"Unknown attention type: {model_args.attention_type}. " + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + layers = torch.nn.ModuleList() + cls = ATTENTION_REGISTRY[model_args.attention_type] + for layer_id in range(model_args.n_layers): + attention = cls(model_args, layer_id, rope) + transformer_block = TransformerBlock(model_args, attention) + layers.append(transformer_block) + + return Transformer(model_args, layers, rope) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 2c82841c573..d6400c29db8 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -15,9 +15,10 @@ get_checkpoint_dtype, get_default_model_resource_dir, ) -from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope from torchao.utils import TorchAOBaseTensor try: @@ -174,7 +175,7 @@ def __init__(self, **kwargs): # They possess all other metadata a tensor carries such as size, stride, requires_grad. with torch.device("meta"): # Model itself is loaded in default dtype, fp32. - self.model_ = Transformer(model_args) + self.model_ = construct_transformer(model_args) # Get checkpoint dtype. 
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
diff --git a/examples/models/llama/tests/test_pre_quantization_transforms.py b/examples/models/llama/tests/test_pre_quantization_transforms.py
index 345f3fad9ba..dc1f9c6cd71 100644
--- a/examples/models/llama/tests/test_pre_quantization_transforms.py
+++ b/examples/models/llama/tests/test_pre_quantization_transforms.py
@@ -7,7 +7,10 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import (
+    construct_transformer,
+    Transformer,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.source_transformation.pre_quantization import (
     sanitize_checkpoint_from_pre_quantization,
@@ -39,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer:
             vocab_size=32000,
         )
 
-        model = Transformer(model_args)
+        model = construct_transformer(model_args)
 
         return model
 
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index a1b6742416e..77b8be5d401 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -2,7 +2,7 @@
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
@@ -160,10 +160,10 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        mha_transformer = Transformer(config).eval()
+        mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
-        static_transformer = Transformer(config).eval()
+        static_transformer = construct_transformer(config).eval()
         static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
         for mha_layer, static_layer in zip(
             mha_transformer.layers, static_transformer.layers
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 351356607c8..7bcf560536c 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -12,7 +12,7 @@
 import requests
 import torch
 
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 
 from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
@@ -66,7 +66,7 @@ def __init__(
             use_hf_rope=True,
             max_seq_len=max_seq_len,
         )
-        self.text_model = Transformer(self.text_model_args)
+        self.text_model = construct_transformer(self.text_model_args)
         # use custom op for SDPA.
         if use_sdpa_with_kv_cache_op:
             self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model)
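
For context, a minimal sketch of how the new construction path is meant to be used. The small `ModelArgs` values are placeholders, and a LoRA-customized attention would be dropped in where the stock `AttentionMHA` is constructed below; this illustrates the API introduced by the patch rather than adding anything to it.

```python
import torch

from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.llama_transformer import (
    construct_transformer,
    Transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

# Small placeholder config for illustration only.
args = ModelArgs(dim=64, n_heads=4, n_layers=4, vocab_size=128)

# The old Transformer(args) one-liner is now provided by the helper.
model = construct_transformer(args)

# Equivalent manual construction: build one attention per layer and pass the
# finished blocks into Transformer. A custom (e.g. LoRA) attention would be
# substituted for AttentionMHA at this point.
rope = Rope(args)
layers = torch.nn.ModuleList(
    [
        TransformerBlock(args, AttentionMHA(args, layer_id, rope))
        for layer_id in range(args.n_layers)
    ]
)
model = Transformer(args, layers, rope)
```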