Commit 4628f89

Fix internal tests for LiquidAI LFM2
Differential Revision: D81491136
Pull Request resolved: pytorch#13916
1 parent: f3b902a

File tree

5 files changed: +41 −9 lines changed

examples/models/lfm2/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -1,5 +1,10 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 from executorch.examples.models.lfm2.convert_weights import convert_weights
+from executorch.examples.models.lfm2.short_conv import ShortConvBlock
 
 __all__ = [
     "convert_weights",
+    "ShortConvBlock",
 ]
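
With ShortConvBlock re-exported from the package __init__, callers can import it from the lfm2 package directly instead of reaching into the short_conv module. A minimal sketch of what the re-export enables (assuming an executorch checkout is on the Python path):

    # Both names are now part of the lfm2 package's public surface.
    from executorch.examples.models.lfm2 import ShortConvBlock, convert_weights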

examples/models/lfm2/short_conv.py

Lines changed: 5 additions & 3 deletions
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from executorch.examples.models.llama.attention import ForwardOptions
 from executorch.examples.models.llama.feed_forward import FeedForward
@@ -12,8 +14,8 @@ def __init__(
         dim: int,
         L_cache: int = 3,
         bias: bool = False,
-        device: torch.device = None,
-        dtype: torch.dtype = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.dim = dim
@@ -99,7 +101,7 @@ def forward(
         x,
         freqs_cos=None,
         freqs_sin=None,
-        _unused_attn_options: ForwardOptions = None,
+        _unused_attn_options: Optional[ForwardOptions] = None,
     ): # x: 1xN
         h = self.conv.forward(self.attention_norm(x))
         h = x + h
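
A parameter annotated as a bare torch.device with a default of None is an implicit Optional, which strict type checkers reject under PEP 484's strict reading (the # pyre-ignore added in llama_transformer.py suggests pyre is the checker the internal tests run). An illustrative standalone sketch of the corrected pattern; the class and field names here are made up, not the real ShortConv:

    from typing import Optional

    import torch
    import torch.nn as nn


    class ConvProj(nn.Module):
        """Illustrative only: mirrors the annotation fix, not the real module."""

        def __init__(
            self,
            dim: int,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
        ) -> None:
            super().__init__()
            # A `None` default requires Optional[...]; spelling it out keeps
            # strict checkers such as pyre happy.
            self.proj = nn.Linear(dim, dim, bias=False, device=device, dtype=dtype)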

examples/models/llama/TARGETS

Lines changed: 18 additions & 3 deletions
@@ -13,6 +13,24 @@ runtime.python_library(
     name = "llama_transformer",
     srcs = [
         "llama_transformer.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":transformer_modules",
+        "//caffe2:torch",
+        "//executorch/examples/models/lfm2:lfm2",
+    ],
+)
+
+runtime.python_library(
+    name = "transformer_modules",
+    srcs = [
+        "feed_forward.py",
         "lora.py",
         "rope.py",
         "attention.py",
@@ -25,9 +43,6 @@ runtime.python_library(
         "//executorch/...",
         "@EXECUTORCH_CLIENTS",
     ],
-    deps = [
-        "//caffe2:torch",
-    ],
 )
 
 runtime.python_library(
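
The llama_transformer target now depends on //executorch/examples/models/lfm2:lfm2 (for ShortConvBlock), while lfm2's short_conv.py imports the llama attention and feed-forward modules. Splitting those shared modules into a separate transformer_modules target gives both sides a common dependency, presumably to avoid a target-level cycle. A hypothetical sketch of how an lfm2 TARGETS entry could consume the split-out target; the actual //executorch/examples/models/lfm2 TARGETS file is not shown in this diff:

    # Hypothetical, not part of this commit.
    runtime.python_library(
        name = "lfm2",
        srcs = [
            "short_conv.py",
            "convert_weights.py",
        ],
        deps = [
            "//caffe2:torch",
            # Depends on the shared modules, not on :llama_transformer, so no
            # dependency cycle is introduced.
            "//executorch/examples/models/llama:transformer_modules",
        ],
    )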

examples/models/llama/feed_forward.py

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int):
         super().__init__()
-        assert hidden_dim is not None
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
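
The assert is redundant inside FeedForward, since hidden_dim is already annotated as a plain int; the None check moves to the call sites instead (see the llama_transformer.py hunk below). A minimal sketch of that "validate where the value is still Optional" pattern; build_ffn is a made-up helper, not code from this commit:

    from typing import Optional

    import torch.nn as nn


    class FeedForward(nn.Module):
        def __init__(self, dim: int, hidden_dim: int) -> None:
            super().__init__()
            # (forward omitted; this sketch only concerns construction.)
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)


    def build_ffn(dim: int, hidden_dim: Optional[int]) -> FeedForward:
        # Resolve the Optional before FeedForward sees it, keeping the
        # `hidden_dim: int` annotation honest.
        assert hidden_dim is not None, "`hidden_dim` must be set"
        return FeedForward(dim=dim, hidden_dim=hidden_dim)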

examples/models/llama/llama_transformer.py

Lines changed: 13 additions & 2 deletions
@@ -12,7 +12,6 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.lfm2.short_conv import ShortConvBlock
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
@@ -87,10 +86,15 @@ def __init__(self, args: ModelArgs, attention: Attention):
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+
+        assert (
+            args.hidden_dim is not None
+        ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
+
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
@@ -245,6 +249,11 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
     for layer_id in range(model_args.n_layers):
         # hybrid models define layer_types
         if model_args.layer_types and model_args.layer_types[layer_id] == "conv":
+            from executorch.examples.models.lfm2.short_conv import ShortConvBlock
+
+            assert (
+                model_args.hidden_dim is not None
+            ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
             layers.append(
                 ShortConvBlock(
                     dim=model_args.dim,
@@ -253,7 +262,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                 )
             )
         else:
-            attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
+            attention = cls(
+                model_args, layer_id, rope, **model_args.attention_kwargs
+            ) # pyre-ignore[45]
             transformer_block = TransformerBlock(model_args, attention)
             layers.append(transformer_block)
 
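
Deferring the ShortConvBlock import into the conv branch means llama_transformer.py no longer needs the lfm2 package at import time; it is only pulled in when a hybrid model actually requests a conv layer, which also sidesteps any potential import-order issues between the llama and lfm2 example packages now that lfm2's __init__ and short_conv import from the llama modules. A sketch of the deferred-import pattern in isolation; the pkg_a/pkg_b names are hypothetical, not the real executorch layout:

    # pkg_a/__init__.py re-exports pkg_a.block, and pkg_a.block imports pkg_b
    # helpers; pkg_b.builder in turn needs pkg_a.block.Block. A top-level import
    # in pkg_b/builder.py could close the cycle, so it is deferred:
    def build_layer(kind: str, dim: int):
        if kind == "conv":
            # Resolved only when this branch runs, after both packages have
            # finished importing.
            from pkg_a.block import Block

            return Block(dim)
        raise ValueError(f"unsupported layer kind: {kind}")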
