Commit 8607c89

tdakhran and jackzhxng authored
model : support LiquidAI LFM2 hybrid family (pytorch#13805)
### Summary

Add support for the [LiquidAI LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) model family. For more information about the models, please read [the blog post](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models).

- Support the hybrid LFM2-350M, LFM2-700M, and LFM2-1.2B models.
- Add `ShortConvBlock`.
- Modify `construct_transformer` to construct hybrid architectures.
- Move `FeedForward` to avoid a cyclic dependency.

Instructions are in `examples/models/lfm2/README.md`.

### Test plan

All commands in `README.md` are tested.

```
❯ python -m examples.models.llama.runner.native \
    --model lfm2_700m \
    --pte lfm2_700m_8da4w.pte \
    --tokenizer ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
    --tokenizer_config ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer_config.json \
    --prompt "<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
    --params examples/models/lfm2/config/lfm2_700m_config.json \
    --max_len 128 \
    -kv \
    --temperature 0.3
...
I'm an AI designed to assist with generating text based on the prompts you provide. I'm a type of language model, but I don't have a physical form or consciousness. I operate based on complex algorithms and vast amounts of training data. How can I help you today? If you have a specific question or need assistance with something, feel free to ask!
...
```

---------

Co-authored-by: Jack <[email protected]>
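For orientation only (this sketch is not code from the commit): the hybrid construction amounts to choosing, per layer, between an attention block and a `ShortConvBlock`, driven by the `layer_types` list in the params JSON. A minimal sketch of that dispatch, with hypothetical `make_attention_block` / `make_conv_block` factories standing in for the real block constructors used by `construct_transformer`:

```
import torch.nn as nn


def build_hybrid_layers(layer_types, make_attention_block, make_conv_block):
    # One block per entry: "full_attention" -> attention, "conv" -> short conv.
    layers = nn.ModuleList()
    for layer_type in layer_types:
        if layer_type == "full_attention":
            layers.append(make_attention_block())
        else:
            layers.append(make_conv_block())
    return layers
```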
1 parent 0eed262 commit 8607c89

File tree

15 files changed (+424 -19 lines)


examples/models/lfm2/README.md

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
## Summary

[LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) is a new generation of hybrid models developed by [Liquid AI](https://www.liquid.ai/), available in 3 variants: 350M, 700M, and 1.2B.

## Instructions

LFM2 uses the same example code as the optimized Llama model, while the checkpoint, model params, and tokenizer are different. Please see the [Llama README page](../llama/README.md) for details.
LFM2 is a hybrid model, where some attention layers are replaced with short convolutions.

### Example export

Here is a basic example for exporting LFM2; please refer to the Llama README's [Step 2: Prepare model](../llama/README.md#step-2-prepare-model) for more advanced usage.

Export 350M to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
  --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
  +base.model_class="lfm2_350m" \
  +base.params="examples/models/lfm2/config/lfm2_350m_config.json" \
  +export.output_name="lfm2_350m_8da4w.pte"
```

Export 700M to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
  --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
  +base.model_class="lfm2_700m" \
  +base.params="examples/models/lfm2/config/lfm2_700m_config.json" \
  +export.output_name="lfm2_700m_8da4w.pte"
```

Export 1.2B to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
  --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
  +base.model_class="lfm2_1_2b" \
  +base.params="examples/models/lfm2/config/lfm2_1_2b_config.json" \
  +export.output_name="lfm2_1_2b_8da4w.pte"
```

### Example run

With ExecuTorch pybindings:
```
python -m examples.models.llama.runner.native \
  --model lfm2_700m \
  --pte lfm2_700m_8da4w.pte \
  --tokenizer ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
  --tokenizer_config ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer_config.json \
  --prompt "<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
  --params examples/models/lfm2/config/lfm2_700m_config.json \
  --max_len 128 \
  -kv \
  --temperature 0.3
```

With ExecuTorch's sample C++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner):
```
cmake-out/examples/models/llama/llama_main \
  --model_path lfm2_700m_8da4w.pte \
  --tokenizer_path ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
  --prompt="<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
  --temperature 0.3
```

To run the model on an example iOS or Android app, see the Llama README's [Step 5: Build Mobile apps](../llama/README.md#step-5-build-mobile-apps) section.

examples/models/lfm2/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from executorch.examples.models.lfm2.convert_weights import convert_weights

__all__ = [
    "convert_weights",
]

examples/models/lfm2/config/lfm2_1_2b_config.json

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
{
  "dim": 2048,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 8192,
  "n_heads": 32,
  "n_kv_heads": 8,
  "n_layers": 16,
  "norm_eps": 1e-5,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 65536,
  "use_hf_rope": true,
  "use_qk_norm": true,
  "qk_norm_before_rope": true,
  "layer_types": [
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "conv"
  ]
}
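
As a quick, hedged illustration of how these params files describe the hybrid stack (this snippet is not part of the commit), one can load a config and tally the layer kinds; the path below assumes this file is the 1.2B config:

```
import json
from collections import Counter

# Path is an assumption: point it at whichever LFM2 params file you use.
with open("examples/models/lfm2/config/lfm2_1_2b_config.json") as f:
    params = json.load(f)

# Count "full_attention" vs. "conv" entries declared in layer_types.
print(Counter(params["layer_types"]))
```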

examples/models/lfm2/config/lfm2_350m_config.json

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
{
  "dim": 1024,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 4608,
  "n_heads": 16,
  "n_kv_heads": 8,
  "n_layers": 16,
  "norm_eps": 1e-5,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 65536,
  "use_hf_rope": true,
  "use_qk_norm": true,
  "qk_norm_before_rope": true,
  "layer_types": [
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "conv"
  ]
}

examples/models/lfm2/config/lfm2_700m_config.json

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
{
  "dim": 1536,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 6912,
  "n_heads": 24,
  "n_kv_heads": 8,
  "n_layers": 16,
  "norm_eps": 1e-5,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 65536,
  "use_hf_rope": true,
  "use_qk_norm": true,
  "qk_norm_before_rope": true,
  "layer_types": [
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "conv"
  ]
}
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
base:
  metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}'

model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
  dtype_override: fp32

backend:
  xnnpack:
    enabled: True
    extended_ops: True

examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
base:
  metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}'

model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
  dtype_override: fp32

quantization:
  qmode: 8da4w

backend:
  xnnpack:
    enabled: True
    extended_ops: True
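
The `+key=value` arguments in the README look like Hydra/OmegaConf-style overrides layered on top of this yaml. Purely as an illustration of that merge behaviour (not the actual `export_llm` internals), and assuming `omegaconf` is installed:

```
from omegaconf import OmegaConf

# Load the backend/quantization defaults, then layer the per-model overrides
# on top, mirroring the `+key=value` CLI arguments from the README.
base = OmegaConf.load("examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml")
overrides = OmegaConf.from_dotlist(
    [
        "base.model_class=lfm2_700m",
        "base.params=examples/models/lfm2/config/lfm2_700m_config.json",
        "export.output_name=lfm2_700m_8da4w.pte",
    ]
)
cfg = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(cfg))
```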

examples/models/lfm2/convert_weights.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key

_LFM_2_TO_META = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.embedding_norm.weight": "norm.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
    "model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
    "model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
}


def lfm_2_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from LFM2 HF format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in LFM2 HF format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        try:
            new_key = get_mapped_key(key, _LFM_2_TO_META)
        except:
            new_key = key.removeprefix("model.")

        # split in_proj
        if new_key.endswith(".conv.in_proj.weight"):
            for name, split_value in zip(
                ["B_proj", "C_proj", "x_proj"], torch.chunk(value, 3, dim=0)
            ):
                converted_state_dict[new_key.replace("in_proj", name)] = split_value
        else:
            converted_state_dict[new_key] = value

    # If lm_head.weight is not present in state dict, assume tied embeddings
    if "lm_head.weight" not in state_dict:
        converted_state_dict["output.weight"] = converted_state_dict[
            "tok_embeddings.weight"
        ]

    return converted_state_dict


def load_checkpoint(input_dir: str) -> Dict:
    print("Loading checkpoint from safetensors directory")
    state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
    return state_dict


def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    sd = lfm_2_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")

examples/models/lfm2/short_conv.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
import torch
from executorch.examples.models.llama.attention import ForwardOptions
from executorch.examples.models.llama.feed_forward import FeedForward

from executorch.examples.models.llama.norm import RMSNorm
from torch import nn


class ShortConv(nn.Module):
    def __init__(
        self,
        dim: int,
        L_cache: int = 3,
        bias: bool = False,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        super().__init__()
        self.dim = dim
        self.L_cache = L_cache
        self.device = device
        self.dtype = dtype
        self.bias = bias

        self.conv = nn.Conv1d(
            dim,
            dim,
            kernel_size=L_cache,
            padding=0,  ## we don't need padding since we handle it manually
            groups=dim,
            bias=bias,
        )

        conv_state = torch.zeros(
            1,  ## batch size is assumed to be 1 for now
            dim,
            L_cache - 1,
            device="cpu",
        )
        self.register_buffer("conv_state", conv_state)

        ## better performance in ExecuTorch with separate projections
        self.B_proj = nn.Linear(dim, dim, bias=bias)
        self.C_proj = nn.Linear(dim, dim, bias=bias)
        self.x_proj = nn.Linear(dim, dim, bias=bias)

        self.out_proj = nn.Linear(dim, dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seqlen, dim = x.size()
        assert batch_size == 1, "batch_size must be 1"

        B = self.B_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
        C = self.C_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
        x = self.x_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)

        Bx = B * x  # (batch_size, dim, seq_len)

        ## This is where we handle padding.
        ## By default, the conv_state is initialized to 0. So, assuming prefill is
        ## done on an empty cache, concatenating conv_state to the beginning of the
        ## sequence acts similarly to using nn.Conv1d(padding=L_cache-1) (for prefill)
        ## without any manual padding.
        ## However, the manual padding has the added benefit of being correct during
        ## decode, when the cache is not initialized to 0.
        Bx = torch.cat(
            [self.conv_state, Bx], dim=-1
        )  # (batch_size, dim, seq_len + L_cache - 1)

        ## Update the conv_state
        new_conv_state = Bx[
            ..., -(self.L_cache - 1) :
        ]  # (batch_size, dim, L_cache - 1)
        with torch.no_grad():
            self.conv_state.copy_(new_conv_state)

        conv_out = self.conv(Bx)[..., : x.size(-1)]  # (batch_size, dim, seq_len)
        y = C * conv_out  # (batch_size, dim, seq_len)

        y = y.transpose(-1, -2)  # (batch_size, seq_len, dim)
        y = y.contiguous()  # (batch_size, seq_len, dim)
        y = self.out_proj(y)  # (batch_size, seq_len, dim)
        return y

    def reset_cache(self):
        self.conv_state.zero_()


class ShortConvBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
        super().__init__()
        self.L_cache = 3  # hardcode 3 for now
        self.conv = ShortConv(dim, self.L_cache, bias=False)
        self.feed_forward = FeedForward(dim, hidden_dim)
        self.ffn_norm = RMSNorm(dim, norm_eps)
        # use attention_norm instead of operator_norm to unify with TransformerBlock
        self.attention_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x,
        freqs_cos=None,
        freqs_sin=None,
        _unused_attn_options: ForwardOptions = None,
    ):  # x: 1xN
        h = self.conv.forward(self.attention_norm(x))
        h = x + h
        out = h + self.feed_forward(self.ffn_norm(h))
        return out, None

    def reset_cache(self):
        self.conv.reset_cache()
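
To make the caching behaviour concrete, here is a small, hedged usage sketch (not part of the commit): a prefill pass followed by a single-token decode step on a standalone `ShortConv`, assuming the ExecuTorch Python package is importable:

```
import torch

from executorch.examples.models.lfm2.short_conv import ShortConv

conv = ShortConv(dim=64)          # L_cache defaults to 3

prefill = torch.randn(1, 8, 64)   # (batch=1, seq_len=8, dim)
out = conv(prefill)               # conv_state now holds the last L_cache-1 columns of B*x
assert out.shape == (1, 8, 64)

token = torch.randn(1, 1, 64)     # one decode step
step = conv(token)                # reads and updates the rolling conv_state
assert step.shape == (1, 1, 64)

conv.reset_cache()                # zero the state before starting a new sequence
```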
