From 6375fc2662bfe2e25d1a85d41223df7172a84bdc Mon Sep 17 00:00:00 2001
From: lucylq
Date: Thu, 1 May 2025 13:33:56 -0700
Subject: [PATCH] Refactor attention v2

Pull attention creation out of Transformer/TransformerBlock. Instead,
pass the layers into Transformer. The motivation is to customize the
linear layers in attention for LoRA (e.g. make wq a LoraLinear instead
of a regular linear).

In the next diff (D73517350), we pull wq, wk, wv, wo out of the
attention and pass those in as well. This allows us to customize
attention parameters without passing in ModelArgs and doing the
customization deep inside attention.py.

I think this modularizes our attention/transformer components, though
it also means that users have to do a bit more work to construct the
attention layers and pass them to the Transformer. It follows the
torchtune structure more closely, e.g.
https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
---
 examples/models/llama/llama_transformer.py    | 50 ++++++++++++++-----
 examples/models/llama/model.py                |  5 +-
 .../tests/test_pre_quantization_transforms.py |  4 +-
 .../llama/tests/test_static_attention.py      |  6 +--
 4 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 5c8db7f208d..0dd231b7ecb 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,25 +84,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
" + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + cls = ATTENTION_REGISTRY[args.attention_type] + attention = cls(args, layer_id, rope) + return TransformerBlock(args, attention) def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN h, attn_options_update = self.attention.forward( @@ -117,7 +123,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: class Transformer(nn.Module): - def __init__(self, params: ModelArgs): + def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope): super().__init__() self.params = params self.vocab_size = params.vocab_size @@ -130,10 +136,8 @@ def __init__(self, params: ModelArgs): if self.apply_embedding else None ) - self.rope = Rope(params) - self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params, self.rope)) + self.layers = layers + self.rope = rope self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = ( nn.Linear(params.dim, params.vocab_size, bias=False) @@ -212,3 +216,23 @@ def forward( return logits, attn_options_update return logits + + +def construct_transformer(model_args: ModelArgs) -> Transformer: + """ + Construct a Transformer model from the given model arguments. + """ + rope = Rope(model_args) + if model_args.attention_type not in ATTENTION_REGISTRY: + raise ValueError( + f"Unknown attention type: {model_args.attention_type}. " + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + layers = torch.nn.ModuleList() + cls = ATTENTION_REGISTRY[model_args.attention_type] + for layer_id in range(model_args.n_layers): + attention = cls(model_args, layer_id, rope) + transformer_block = TransformerBlock(model_args, attention) + layers.append(transformer_block) + + return Transformer(model_args, layers, rope) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 2c82841c573..d6400c29db8 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -15,9 +15,10 @@ get_checkpoint_dtype, get_default_model_resource_dir, ) -from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope from torchao.utils import TorchAOBaseTensor try: @@ -174,7 +175,7 @@ def __init__(self, **kwargs): # They possess all other metadata a tensor carries such as size, stride, requires_grad. with torch.device("meta"): # Model itself is loaded in default dtype, fp32. - self.model_ = Transformer(model_args) + self.model_ = construct_transformer(model_args) # Get checkpoint dtype. 
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
diff --git a/examples/models/llama/tests/test_pre_quantization_transforms.py b/examples/models/llama/tests/test_pre_quantization_transforms.py
index 345f3fad9ba..3a55f93d562 100644
--- a/examples/models/llama/tests/test_pre_quantization_transforms.py
+++ b/examples/models/llama/tests/test_pre_quantization_transforms.py
@@ -7,7 +7,7 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer, Transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.source_transformation.pre_quantization import (
     sanitize_checkpoint_from_pre_quantization,
@@ -39,7 +39,7 @@ def _prepare_dummy_model(self) -> Transformer:
             vocab_size=32000,
         )
 
-        model = Transformer(model_args)
+        model = construct_transformer(model_args)
 
         return model
 
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index a1b6742416e..77b8be5d401 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -2,7 +2,7 @@
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
@@ -160,10 +160,10 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        mha_transformer = Transformer(config).eval()
+        mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
-        static_transformer = Transformer(config).eval()
+        static_transformer = construct_transformer(config).eval()
         static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
         for mha_layer, static_layer in zip(
             mha_transformer.layers, static_transformer.layers
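
For illustration, a minimal sketch of the construction flow this patch enables (not part of the diff): it mirrors construct_transformer(), but builds the per-layer attention modules by hand so a caller can customize them before handing them to Transformer, e.g. swapping attention linears for LoRA variants as planned in D73517350. Module paths and constructor signatures follow the hunks above; the ModelArgs fields and the AttentionMHA(args, layer_id, rope) signature are assumptions based on the registry call and the tests.

import torch.nn as nn

from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.llama_transformer import (
    Transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

# Small config for illustration; fields mirror the ones used in the tests.
args = ModelArgs(dim=64, n_heads=4, n_layers=2, vocab_size=128)
rope = Rope(args)

layers = nn.ModuleList()
for layer_id in range(args.n_layers):
    # Build attention explicitly instead of letting TransformerBlock do it.
    attention = AttentionMHA(args, layer_id, rope)
    # Hook point: wrap or replace attention's linear layers (e.g. wq) with
    # LoRA variants here before constructing the block.
    layers.append(TransformerBlock(args, attention))

model = Transformer(args, layers, rope)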