
Commit e008d6c

Address PR feedback
1 parent 0ea6faa commit e008d6c

4 files changed, +34 -25 lines changed


examples/models/lfm2/convert_weights.py

Lines changed: 3 additions & 4 deletions
@@ -12,16 +12,13 @@
 _LFM_2_TO_META = {
     "model.embed_tokens.weight": "tok_embeddings.weight",
     "model.embedding_norm.weight": "norm.weight",
-
     "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
     "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
     "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
     "model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
     "model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
     "model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",
-
     "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-
     "model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
 }

@@ -48,7 +45,9 @@ def lfm_2_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor

         # split in_proj
         if new_key.endswith(".conv.in_proj.weight"):
-            for name, split_value in zip(["B_proj", "C_proj", "x_proj"], torch.chunk(value, 3, dim=0)):
+            for name, split_value in zip(
+                ["B_proj", "C_proj", "x_proj"], torch.chunk(value, 3, dim=0)
+            ):
                 converted_state_dict[new_key.replace("in_proj", name)] = split_value
         else:
             converted_state_dict[new_key] = value
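
For context, a minimal standalone sketch of the in_proj split handled above: the fused conv.in_proj weight is chunked along dim 0 into separate B_proj, C_proj and x_proj entries. The dim value and the example key below are made-up illustrations, not values from a real LFM2 checkpoint.

import torch

dim = 8
state_dict = {
    # hypothetical fused weight: three stacked (dim, dim) projections
    "layers.0.attention.conv.in_proj.weight": torch.randn(3 * dim, dim)
}

converted_state_dict = {}
for new_key, value in state_dict.items():
    if new_key.endswith(".conv.in_proj.weight"):
        # chunk along dim 0 so each split keeps the full input dimension
        for name, split_value in zip(
            ["B_proj", "C_proj", "x_proj"], torch.chunk(value, 3, dim=0)
        ):
            converted_state_dict[new_key.replace("in_proj", name)] = split_value
    else:
        converted_state_dict[new_key] = value

assert converted_state_dict["layers.0.attention.conv.B_proj.weight"].shape == (dim, dim)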

examples/models/lfm2/short_conv.py

Lines changed: 18 additions & 9 deletions
@@ -1,10 +1,10 @@
 import torch
-from torch import nn
-
-from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.attention import ForwardOptions
 from executorch.examples.models.llama.feed_forward import FeedForward

+from executorch.examples.models.llama.norm import RMSNorm
+from torch import nn
+

 class ShortConv(nn.Module):
     def __init__(

@@ -61,10 +61,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # So, assuming prefill is done on an empty cache, concatenating conv_state to the beginning of the sequence acts similary to
         ## using nn.Conv1d(padding=L_cache-1) (for prefill) without no manual padding.
         ## However, the manual padding has the added benefit of being correct during decode, when the cache is not initialized to 0.
-        Bx = torch.cat([self.conv_state, Bx], dim=-1)  # (batch_size, dim, seq_len + L_cache - 1)
+        Bx = torch.cat(
+            [self.conv_state, Bx], dim=-1
+        )  # (batch_size, dim, seq_len + L_cache - 1)

         ## Update the conv_state
-        new_conv_state = Bx[..., -(self.conv.weight.size(-1) - 1) :]  # (batch_size, dim, L_cache - 1)
+        new_conv_state = Bx[
+            ..., -(self.L_cache - 1) :
+        ]  # (batch_size, dim, L_cache - 1)
         with torch.no_grad():
             self.conv_state.copy_(new_conv_state)

@@ -83,15 +87,20 @@ def reset_cache(self):
 class ShortConvBlock(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
         super().__init__()
-        # hardcode 3 for now
-        L_cache = 3
-        self.conv = ShortConv(dim, L_cache, bias=False)
+        self.L_cache = 3  # hardcode 3 for now
+        self.conv = ShortConv(dim, self.L_cache, bias=False)
         self.feed_forward = FeedForward(dim, hidden_dim)
         self.ffn_norm = RMSNorm(dim, norm_eps)
         # use attention_norm norm instead of operator_norm to unify with TransformerBlock
         self.attention_norm = RMSNorm(dim, norm_eps)

-    def forward(self, x, freqs_cos=None, freqs_sin=None, _unused_attn_options: ForwardOptions = None):  # x: 1xN
+    def forward(
+        self,
+        x,
+        freqs_cos=None,
+        freqs_sin=None,
+        _unused_attn_options: ForwardOptions = None,
+    ):  # x: 1xN
         h = self.conv.forward(self.attention_norm(x))
         h = x + h
         out = h + self.feed_forward(self.ffn_norm(h))
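
As a side note on the padding comment kept in the second hunk: below is a small standalone check (not part of the commit) that concatenating an all-zero conv_state of length L_cache - 1 in front of the sequence gives the same prefill result as explicitly left-padding the input, which is what the comment describes. The depthwise (groups=dim) conv and the shapes are assumptions for illustration; ShortConv's projections and gating are omitted.

import torch
import torch.nn.functional as F
from torch import nn

dim, L_cache, seq_len = 4, 3, 6
conv = nn.Conv1d(dim, dim, L_cache, groups=dim, bias=False)  # assumed depthwise short conv
Bx = torch.randn(1, dim, seq_len)

# (a) causal prefill via explicit left padding of L_cache - 1 zeros
padded_out = F.conv1d(F.pad(Bx, (L_cache - 1, 0)), conv.weight, groups=dim)

# (b) emulate an empty cache: conv_state is all zeros of length L_cache - 1
conv_state = torch.zeros(1, dim, L_cache - 1)
cached_out = F.conv1d(torch.cat([conv_state, Bx], dim=-1), conv.weight, groups=dim)

assert torch.allclose(padded_out, cached_out)  # identical while the cache is empty

During decode the cache is no longer zero, so only the concatenation path stays correct, which is the benefit the comment points out.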

examples/models/llama/feed_forward.py

Lines changed: 2 additions & 3 deletions
@@ -1,5 +1,6 @@
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
+

 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int):

@@ -11,5 +12,3 @@ def __init__(self, dim: int, hidden_dim: int):

     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
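
The forward kept above is the usual gated (SwiGLU-style) MLP, w2(silu(w1(x)) * w3(x)). A self-contained sketch for reference; the w1/w2/w3 definitions in __init__ are assumed here to be bias-free Linear layers, which this diff does not show.

import torch
import torch.nn.functional as F
from torch import nn


class GatedFeedForward(nn.Module):  # sketch mirroring the FeedForward above
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


ffn = GatedFeedForward(dim=16, hidden_dim=64)
print(ffn(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])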

examples/models/llama/llama_transformer.py

Lines changed: 11 additions & 9 deletions
@@ -12,17 +12,16 @@
 import torch
 import torch.nn.functional as F

+from executorch.examples.models.lfm2.short_conv import ShortConvBlock
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
-from executorch.examples.models.lfm2.short_conv import ShortConvBlock
-
+from executorch.examples.models.llama.feed_forward import FeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
-from executorch.examples.models.llama.feed_forward import FeedForward
 from torch import nn


@@ -247,12 +246,15 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
         # hybrid models define layer_types
         if model_args.layer_types and model_args.layer_types[layer_id] == "conv":
             layers.append(
-                ShortConvBlock(dim=model_args.dim, hidden_dim=model_args.hidden_dim, norm_eps=model_args.norm_eps)
+                ShortConvBlock(
+                    dim=model_args.dim,
+                    hidden_dim=model_args.hidden_dim,
+                    norm_eps=model_args.norm_eps,
+                )
             )
-            continue
-
-        attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
-        transformer_block = TransformerBlock(model_args, attention)
-        layers.append(transformer_block)
+        else:
+            attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
+            transformer_block = TransformerBlock(model_args, attention)
+            layers.append(transformer_block)

     return Transformer(model_args, layers, rope)
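
A standalone sketch of the layer selection that the if/else above performs: hybrid models list a per-layer type, and "conv" picks the conv block while everything else falls through to the attention block. The ConvBlock/AttnBlock classes below are placeholders, not the real ShortConvBlock/TransformerBlock.

from torch import nn


class ConvBlock(nn.Module):  # placeholder standing in for ShortConvBlock
    def __init__(self, dim):
        super().__init__()
        self.mix = nn.Conv1d(dim, dim, kernel_size=3, groups=dim)


class AttnBlock(nn.Module):  # placeholder standing in for TransformerBlock
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4)


def build_layers(layer_types, dim):
    layers = nn.ModuleList()
    for kind in layer_types:
        # "conv" layers get the conv block, all other types get attention
        layers.append(ConvBlock(dim) if kind == "conv" else AttnBlock(dim))
    return layers


layers = build_layers(["conv", "full_attention", "conv"], dim=64)
print([type(block).__name__ for block in layers])  # ['ConvBlock', 'AttnBlock', 'ConvBlock']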
