diff --git a/examples/models/lfm2/__init__.py b/examples/models/lfm2/__init__.py
index 224282df905..1efdc55af81 100644
--- a/examples/models/lfm2/__init__.py
+++ b/examples/models/lfm2/__init__.py
@@ -1,5 +1,10 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 from executorch.examples.models.lfm2.convert_weights import convert_weights
+from executorch.examples.models.lfm2.short_conv import ShortConvBlock
 
 __all__ = [
     "convert_weights",
+    "ShortConvBlock",
 ]
diff --git a/examples/models/lfm2/short_conv.py b/examples/models/lfm2/short_conv.py
index 5a141d4ce61..ae04580d6c6 100644
--- a/examples/models/lfm2/short_conv.py
+++ b/examples/models/lfm2/short_conv.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from executorch.examples.models.llama.attention import ForwardOptions
 from executorch.examples.models.llama.feed_forward import FeedForward
@@ -12,8 +14,8 @@ def __init__(
         dim: int,
         L_cache: int = 3,
         bias: bool = False,
-        device: torch.device = None,
-        dtype: torch.dtype = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.dim = dim
@@ -99,7 +101,7 @@ def forward(
         x,
         freqs_cos=None,
         freqs_sin=None,
-        _unused_attn_options: ForwardOptions = None,
+        _unused_attn_options: Optional[ForwardOptions] = None,
     ):  # x: 1xN
         h = self.conv.forward(self.attention_norm(x))
         h = x + h
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index c4870ece193..fe26b2f08e0 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -13,6 +13,24 @@ runtime.python_library(
     name = "llama_transformer",
     srcs = [
         "llama_transformer.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":transformer_modules",
+        "//caffe2:torch",
+        "//executorch/examples/models/lfm2:lfm2",
+    ],
+)
+
+runtime.python_library(
+    name = "transformer_modules",
+    srcs = [
+        "feed_forward.py",
         "lora.py",
         "rope.py",
         "attention.py",
@@ -25,9 +43,6 @@ runtime.python_library(
         "//executorch/...",
         "@EXECUTORCH_CLIENTS",
     ],
-    deps = [
-        "//caffe2:torch",
-    ],
 )
 
 runtime.python_library(
diff --git a/examples/models/llama/feed_forward.py b/examples/models/llama/feed_forward.py
index 3e7af7e0dc8..21a4e27df04 100644
--- a/examples/models/llama/feed_forward.py
+++ b/examples/models/llama/feed_forward.py
@@ -5,7 +5,6 @@
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int):
         super().__init__()
-        assert hidden_dim is not None
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index cdbb0c7557c..3a325d0f4f8 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -12,7 +12,6 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.lfm2.short_conv import ShortConvBlock
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
@@ -87,10 +86,15 @@ def __init__(self, args: ModelArgs, attention: Attention):
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+
+        assert (
+            args.hidden_dim is not None
+        ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
+
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
@@ -245,6 +249,11 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
     for layer_id in range(model_args.n_layers):
         # hybrid models define layer_types
        if model_args.layer_types and model_args.layer_types[layer_id] == "conv":
+            from executorch.examples.models.lfm2.short_conv import ShortConvBlock
+
+            assert (
+                model_args.hidden_dim is not None
+            ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
             layers.append(
                 ShortConvBlock(
                     dim=model_args.dim,
@@ -253,7 +262,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                 )
             )
         else:
-            attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
+            attention = cls(
+                model_args, layer_id, rope, **model_args.attention_kwargs
+            )  # pyre-ignore[45]
             transformer_block = TransformerBlock(model_args, attention)
             layers.append(transformer_block)
 
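
Note for reviewers: a minimal standalone sketch of the relocated `hidden_dim` check, not part of the patch. It constructs `FeedForward` directly after validating `hidden_dim` at the call site, mirroring the assertion added to `TransformerBlock`; the sizes, the `Optional[int]` annotation, and the forward-pass shape comment are illustrative assumptions, not taken from the ExecuTorch configs.

    from typing import Optional

    import torch

    from executorch.examples.models.llama.feed_forward import FeedForward

    dim = 64
    hidden_dim: Optional[int] = 256  # illustrative sizes, not from any real config

    # The check that used to live inside FeedForward.__init__ is now the caller's job.
    assert (
        hidden_dim is not None
    ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."

    ffn = FeedForward(dim=dim, hidden_dim=hidden_dim)
    x = torch.randn(1, 8, dim)
    y = ffn(x)  # assuming the usual SwiGLU-style forward, output shape stays (1, 8, dim)
    print(y.shape)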