diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 5c8db7f208d..1fdcdcd91fc 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,19 +84,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -103,6 +106,24 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)
+
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
@@ -212,3 +239,23 @@ def forward(
             return logits, attn_options_update
 
         return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 2c82841c573..d6400c29db8 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -15,9 +15,10 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 from torchao.utils import TorchAOBaseTensor
 
 try:
@@ -174,7 +175,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+            self.model_ = construct_transformer(model_args)
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
diff --git a/examples/models/llama/tests/test_pre_quantization_transforms.py b/examples/models/llama/tests/test_pre_quantization_transforms.py
index 345f3fad9ba..dc1f9c6cd71 100644
--- a/examples/models/llama/tests/test_pre_quantization_transforms.py
+++ b/examples/models/llama/tests/test_pre_quantization_transforms.py
@@ -7,7 +7,10 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import (
+    construct_transformer,
+    Transformer,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.source_transformation.pre_quantization import (
     sanitize_checkpoint_from_pre_quantization,
@@ -39,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer:
             vocab_size=32000,
         )
 
-        model = Transformer(model_args)
+        model = construct_transformer(model_args)
 
         return model
 
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index a1b6742416e..77b8be5d401 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -2,7 +2,7 @@
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
@@ -160,10 +160,10 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        mha_transformer = Transformer(config).eval()
+        mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
-        static_transformer = Transformer(config).eval()
+        static_transformer = construct_transformer(config).eval()
         static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
         for mha_layer, static_layer in zip(
             mha_transformer.layers, static_transformer.layers
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 351356607c8..7bcf560536c 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -12,7 +12,7 @@
 import requests
 import torch
 
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 
 from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
@@ -66,7 +66,7 @@ def __init__(
             use_hf_rope=True,
             max_seq_len=max_seq_len,
         )
-        self.text_model = Transformer(self.text_model_args)
+        self.text_model = construct_transformer(self.text_model_args)
         # use custom op for SDPA.
         if use_sdpa_with_kv_cache_op:
             self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model)
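
Reviewer sketch, not part of the diff: the snippet below shows how the refactored constructors compose. It is a minimal sketch assuming an ExecuTorch checkout where the modules touched above are importable; the ModelArgs values are copied from test_static_attention.py in this diff, and the positional arguments passed to the attention class follow the call sites shown above.

# Reviewer sketch only -- not part of this diff.
from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.llama_transformer import (
    TransformerBlock,
    construct_transformer,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

# Config values copied from test_static_attention.py above.
config = ModelArgs(
    dim=64,
    n_heads=4,
    n_kv_heads=2,
    max_seq_len=8,
    use_kv_cache=True,
    n_layers=4,
    vocab_size=128,
)

# One-call path: replaces the old `Transformer(config)` entry point.
model = construct_transformer(config).eval()

# Dependency-injection path: build the attention yourself, then pass it in.
rope = Rope(config)
attention_cls = ATTENTION_REGISTRY[config.attention_type]
block = TransformerBlock(config, attention_cls(config, 0, rope))

# Legacy-style path kept by the classmethod added in this diff.
block_from_type = TransformerBlock.from_type(0, config, rope)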