Skip to content

Commit 0ea6faa

Browse files
committed
Address PR feedback
1 parent abb5236 commit 0ea6faa

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

examples/models/lfm2/README.md

Lines changed: 2 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -55,10 +55,8 @@ With ExecuTorch's sample c++ runner (see the Llama README's [Step 3: Run on your
5555
cmake-out/examples/models/llama/llama_main \
5656
--model_path lfm2_700m_8da4w.pte \
5757
--tokenizer_path ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
58-
--prompt="<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n"
58+
--prompt="<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
59+
--temperature 0.3
5960
```
6061

6162
To run the model on an example iOS or Android app, see the Llama README's [Step 5: Build Mobile apps](../llama/README.md#step-5-build-mobile-apps) section.
62-
63-
### FAQ
64-
For more help with exporting or running this model, feel free to ask in our [discord channel](https://discord.gg/UEjkY9Zs).

examples/models/lfm2/convert_weights.py

Lines changed: 17 additions & 18 deletions
Original file line number · Diff line number · Diff line change
@@ -9,41 +9,40 @@
99

1010
from torchtune.models.convert_weights import get_mapped_key
1111

12-
_LFM_2_FROM_META = {
13-
"tok_embeddings.weight": "model.embed_tokens.weight",
14-
"norm.weight": "model.embedding_norm.weight",
12+
_LFM_2_TO_META = {
13+
"model.embed_tokens.weight": "tok_embeddings.weight",
14+
"model.embedding_norm.weight": "norm.weight",
1515

16-
"layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
17-
"layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
18-
"layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
19-
"layers.{}.attention.wo.weight": "model.layers.{}.self_attn.out_proj.weight",
20-
"layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_layernorm.weight",
21-
"layers.{}.attention.q_norm_fn.weight": "model.layers.{}.self_attn.q_layernorm.weight",
16+
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
17+
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
18+
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
19+
"model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
20+
"model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
21+
"model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",
2222

23-
"layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
23+
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
2424

25-
"layers.{}.attention_norm.weight": "model.layers.{}.operator_norm.weight",
25+
"model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
2626
}
2727

2828

29-
def lfm_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
29+
def lfm_2_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
3030
"""
31-
Convert a state dict from torchtune's format to Meta's format. This function
31+
Convert a state dict from LFM2 HF format to Meta's format. This function
3232
doesn't handle any sharding or splitting of state dicts. It follows the
3333
state_dict IN -> state_dict OUT pattern.
3434
3535
Args:
36-
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
36+
state_dict (Dict[str, torch.Tensor]): State dict in LFM2 HF format.
3737
3838
Returns:
3939
Dict[str, torch.Tensor]: State dict in Meta's format.
4040
"""
4141
converted_state_dict = {}
42-
inverted_mapping_dict = {v: k for k, v in _LFM_2_FROM_META.items()}
4342

4443
for key, value in state_dict.items():
4544
try:
46-
new_key = get_mapped_key(key, inverted_mapping_dict)
45+
new_key = get_mapped_key(key, _LFM_2_TO_META)
4746
except:
4847
new_key = key.removeprefix("model.")
4948

@@ -54,7 +53,7 @@ def lfm_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
5453
else:
5554
converted_state_dict[new_key] = value
5655

57-
# If lm_head.weight is not present in state dict, assume tied embeddings (e.g., 0.6b and 4b models)
56+
# If lm_head.weight is not present in state dict, assume tied embeddings
5857
if "lm_head.weight" not in state_dict:
5958
converted_state_dict["output.weight"] = converted_state_dict[
6059
"tok_embeddings.weight"
@@ -73,7 +72,7 @@ def convert_weights(input_dir: str, output_file: str) -> None:
7372
print("Loading checkpoint...")
7473
sd = load_checkpoint(input_dir)
7574
print("Converting checkpoint...")
76-
sd = lfm_2_tune_to_meta(sd)
75+
sd = lfm_2_to_meta(sd)
7776
print("Saving checkpoint...")
7877
torch.save(sd, output_file)
7978
print("Done.")

0 commit comments

Comments (0)