5 changes: 5 additions & 0 deletions examples/models/lfm2/__init__.py
@@ -1,5 +1,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.lfm2.convert_weights import convert_weights
from executorch.examples.models.lfm2.short_conv import ShortConvBlock

__all__ = [
"convert_weights",
"ShortConvBlock",
]
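
With these re-exports in place, both helpers can be imported from the package root rather than from their submodules. A minimal sketch of the intended import style, limited to the names that __all__ exposes (constructor arguments are out of scope for this diff):

    # Import via the package root; only the names listed in __all__ above are assumed to exist.
    from executorch.examples.models.lfm2 import ShortConvBlock, convert_weights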
8 changes: 5 additions & 3 deletions examples/models/lfm2/short_conv.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch
from executorch.examples.models.llama.attention import ForwardOptions
from executorch.examples.models.llama.feed_forward import FeedForward
@@ -12,8 +14,8 @@ def __init__(
dim: int,
L_cache: int = 3,
bias: bool = False,
device: torch.device = None,
dtype: torch.dtype = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.dim = dim
@@ -99,7 +101,7 @@ def forward(
x,
freqs_cos=None,
freqs_sin=None,
_unused_attn_options: ForwardOptions = None,
_unused_attn_options: Optional[ForwardOptions] = None,
): # x: 1xN
h = self.conv.forward(self.attention_norm(x))
h = x + h
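The change from device: torch.device = None to Optional[torch.device] (and likewise for dtype) is a typing fix: None is not a member of the bare annotation, so strict checkers reject the old defaults even though torch treats None as "use the default". A minimal standalone sketch of the same pattern, assuming the usual factory-kwarg behavior of torch modules (the class below is illustrative and not part of this PR):

    from typing import Optional

    import torch
    import torch.nn as nn


    class OptionalFactoryKwargs(nn.Module):
        """Illustrative only: mirrors the Optional[...] defaults adopted by ShortConvBlock."""

        def __init__(
            self,
            dim: int,
            device: Optional[torch.device] = None,  # None falls back to torch's current default device
            dtype: Optional[torch.dtype] = None,    # None falls back to torch's current default dtype
        ):
            super().__init__()
            # nn.Linear accepts device/dtype of None and substitutes its own defaults.
            self.proj = nn.Linear(dim, dim, bias=False, device=device, dtype=dtype)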
21 changes: 18 additions & 3 deletions examples/models/llama/TARGETS
@@ -13,6 +13,24 @@ runtime.python_library(
name = "llama_transformer",
srcs = [
"llama_transformer.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama",
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
":transformer_modules",
"//caffe2:torch",
"//executorch/examples/models/lfm2:lfm2",
],
)

runtime.python_library(
name = "transformer_modules",
srcs = [
"feed_forward.py",
"lora.py",
"rope.py",
"attention.py",
@@ -25,9 +43,6 @@ runtime.python_library(
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//caffe2:torch",
],
)

runtime.python_library(
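Splitting feed_forward.py, lora.py, rope.py, attention.py, and friends into their own transformer_modules library lets other models reuse those modules without depending on llama_transformer.py (which now pulls in the lfm2 target). A hedged sketch of how a downstream TARGETS entry might consume the new target, in the same Buck format as the file above; the my_model name and source file are hypothetical:

    runtime.python_library(
        name = "my_model",  # hypothetical consumer target, not part of this PR
        srcs = ["my_model.py"],
        deps = [
            # Reuse attention/feed_forward/rope without depending on llama_transformer.
            "//executorch/examples/models/llama:transformer_modules",
            "//caffe2:torch",
        ],
    )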
1 change: 0 additions & 1 deletion examples/models/llama/feed_forward.py
@@ -5,7 +5,6 @@
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
assert hidden_dim is not None
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
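The dropped assert is redundant here because hidden_dim is annotated as int and the caller now validates it before construction (see the new TransformerBlock assert below). For context, the w1/w2/w3 projections implement the standard Llama-style SwiGLU feed-forward; the forward pass is not part of this diff, so the sketch below is an assumption based on those weight shapes:

    import torch
    import torch.nn.functional as F

    def swiglu_ffn(x: torch.Tensor, w1, w2, w3) -> torch.Tensor:
        # Assumed forward: gate (w1) and up (w3) projections, SiLU gating, then down projection (w2).
        return w2(F.silu(w1(x)) * w3(x))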
15 changes: 13 additions & 2 deletions examples/models/llama/llama_transformer.py
@@ -12,7 +12,6 @@
import torch
import torch.nn.functional as F

from executorch.examples.models.lfm2.short_conv import ShortConvBlock
from executorch.examples.models.llama.attention import (
Attention,
ATTENTION_REGISTRY,
@@ -87,10 +86,15 @@ def __init__(self, args: ModelArgs, attention: Attention):
self.dim = args.dim
self.head_dim = args.head_dim
self.attention = attention

assert (
args.hidden_dim is not None
), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
if args.moe:
self.block_sparse_moe = MOEFeedForward(args)
else:
self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)

self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

@@ -245,6 +249,11 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
for layer_id in range(model_args.n_layers):
# hybrid models define layer_types
if model_args.layer_types and model_args.layer_types[layer_id] == "conv":
from executorch.examples.models.lfm2.short_conv import ShortConvBlock

assert (
model_args.hidden_dim is not None
), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
layers.append(
ShortConvBlock(
dim=model_args.dim,
@@ -253,7 +262,9 @@
)
)
else:
attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
attention = cls(
model_args, layer_id, rope, **model_args.attention_kwargs
) # pyre-ignore[45]
transformer_block = TransformerBlock(model_args, attention)
layers.append(transformer_block)

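Importing ShortConvBlock inside the "conv" branch keeps the lfm2 dependency lazy: pure-attention models never touch it, and hybrid models resolve it only when layer_types actually contains "conv". A hedged usage sketch of building such a hybrid model; every ModelArgs field value below is invented, the ModelArgs import path is assumed, and any required fields not shown in this diff are omitted:

    # Illustrative only; not part of this PR.
    from executorch.examples.models.llama.llama_transformer import construct_transformer
    from executorch.examples.models.llama.model_args import ModelArgs  # assumed location of ModelArgs

    args = ModelArgs(
        dim=1024,
        hidden_dim=4096,  # must be set, or the new asserts in this file fire
        n_layers=4,
        # Layers tagged "conv" become ShortConvBlock; any other tag goes through the attention path.
        layer_types=["conv", "attention", "conv", "attention"],
    )
    model = construct_transformer(args)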