
Commit ed3136f

dongwang218 and mreso authored
Update hf weight conversion script to llama 3 (meta-llama#551)
Co-authored-by: Matthias Reso <[email protected]>
1 parent f6617fb commit ed3136f

File tree: 3 files changed, +25 -17 lines


src/llama_recipes/utils/hf_llama_conversion/README.md renamed to src/llama_recipes/tools/README.md

Lines changed: 4 additions & 4 deletions
@@ -7,16 +7,16 @@ This is the reverse conversion for `convert_llama_weights_to_hf.py` script from
 - Copy file params.json from the official llama download into that directory.
 - Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
 ```
-python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
+python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Meta-Llama-3-70B-Instruct --output-dir test70B --model-size 70B
 ```

 ## Step 1: Run inference
-Checkout the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
+Checkout the official llama 3 inference [repo](https://github.com/meta-llama/llama3). Test using chat or text completion.
 ```
-torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_3_dir}/tokenizer.model
 ```

 For validation, please compare the converted weights with official llama 2 weights
 ```
-python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
+python compare_llama_weights.py test70B ${Llama-3-70B-Instruct_dir}
 ```
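The README commands target the 70B Instruct model; since this commit also adds an "8B" entry to NUM_SHARDS, the same flow applies to the single-shard 8B checkpoint. Below is a minimal sketch of Step 0 driven from Python, assuming the Hub id meta-llama/Meta-Llama-3-8B-Instruct and an illustrative local path for the official Meta download (both are assumptions, not part of the commit):

```
# Sketch only: mirrors the README's Step 0 for a single-shard Llama 3 8B conversion.
import shutil
import subprocess
from pathlib import Path

output_dir = Path("test8B")
output_dir.mkdir(exist_ok=True)

# Copy params.json from the official llama download into the output directory.
shutil.copy("/path/to/Meta-Llama-3-8B-Instruct/params.json", output_dir / "params.json")

# Run the reverse conversion (same CLI as in the README, with --model-size 8B).
subprocess.run(
    [
        "python", "-m", "llama_recipes.tools.convert_hf_weights_to_llama",
        "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",
        "--output-dir", str(output_dir),
        "--model-size", "8B",
    ],
    check=True,
)
```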

src/llama_recipes/utils/hf_llama_conversion/compare_llama_weights.py renamed to src/llama_recipes/tools/compare_llama_weights.py

Lines changed: 7 additions & 5 deletions
@@ -28,23 +28,25 @@ def main() -> None:
         assert len(one) == len(
             two
         ), "shard should have the same length: {} != {}".format(len(one), len(two))
+        one = sorted(one.items(), key=lambda x: x[0])
+        two = sorted(two.items(), key=lambda x: x[0])

-        for _, (v, w) in enumerate(zip(one.items(), two.items())):
+        for _, (v, w) in enumerate(zip(one, two)):
             assert v[0] == w[0], "{} != {}".format(v[0], w[0])
             assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
                 v[0], v[1].shape, w[1].shape
             )

             delta = (v[1] - w[1]).abs().max().item()
-            deltas.append((i, v[0], delta))
+            deltas.append((i, v[0], delta, w[1].abs().mean().item()))
         del one
         del two
         gc.collect()

-    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
+    deltas = sorted(deltas, key=lambda x: x[-2], reverse=True)
     print("Top 10 largest deltas:")
-    for i, k, v in deltas[:10]:
-        print(f" shard {i} {k}: {v}")
+    for i, k, delta, value in deltas[:10]:
+        print(f" shard {i} {k}: {delta} vs {value}")


 if __name__ == "__main__":
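The patch sorts both shards' entries by key before zipping them (so the pairing no longer depends on dict insertion order) and records each reference tensor's mean absolute value, so every printed delta can be read against the typical magnitude of that weight. A self-contained sketch of the updated comparison loop on toy tensors (real usage loads the consolidated.*.pth shards from the two checkpoint directories):

```
import torch

one = {"tok_embeddings.weight": torch.randn(8, 4), "norm.weight": torch.randn(4)}
two = {k: v + 1e-4 * torch.randn_like(v) for k, v in one.items()}

deltas = []
shard_index = 0  # stands in for the outer loop over shard files

# Sort by key so the two state dicts are paired entry by entry.
one_items = sorted(one.items(), key=lambda x: x[0])
two_items = sorted(two.items(), key=lambda x: x[0])

for (k1, v), (k2, w) in zip(one_items, two_items):
    assert k1 == k2 and v.shape == w.shape
    delta = (v - w).abs().max().item()
    # Also keep the reference tensor's mean |value| as a scale for the delta.
    deltas.append((shard_index, k1, delta, w.abs().mean().item()))

# Largest deltas first; the delta is the second-to-last tuple field.
deltas.sort(key=lambda x: x[-2], reverse=True)
for i, k, delta, value in deltas[:10]:
    print(f" shard {i} {k}: {delta} vs {value}")
```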

src/llama_recipes/tools/convert_hf_weights_to_llama.py

Lines changed: 14 additions & 8 deletions
@@ -12,6 +12,7 @@

 NUM_SHARDS = {
     "7B": 1,
+    "8B": 1,
     "13B": 2,
     "34B": 4,
     "30B": 4,
@@ -30,15 +31,12 @@ def write_model(model_path, model_size, output_base_path):
     n_heads_per_shard = n_heads // num_shards
     dim = params["dim"]
     dims_per_head = dim // n_heads
-    base = 10000.0
-    inv_freq = (
-        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    ).to(dtype)
+    llama_version = 3 if params.get("vocab_size") == 128256 else 2

     if "n_kv_heads" in params:
         num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
-        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
-        key_value_dim = dim // num_key_value_heads
+        num_local_key_value_heads = num_key_value_heads // num_shards
+        key_value_dim = dims_per_head * num_key_value_heads
     else: # compatibility with other checkpoints
         num_key_value_heads = n_heads
         num_local_key_value_heads = n_heads_per_shard
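The old GQA/MQA formulas happen to produce correct numbers for Llama 2 70B but break for Llama 3 8B, which is presumably why the rewrite ships together with the new "8B" shard entry. A worked check using the published Llama 3 8B config values (dim=4096, n_heads=32, n_kv_heads=8; treated here as assumptions) and one shard:

```
# Worked example of the GQA sharding arithmetic with assumed Llama 3 8B values
# (dim=4096, n_heads=32, n_kv_heads=8) and the single shard added in this commit.
dim, n_heads, n_kv_heads, num_shards = 4096, 32, 8, 1
dims_per_head = dim // n_heads             # 128
n_heads_per_shard = n_heads // num_shards  # 32

# Old formulas (removed by this commit):
old_local_kv_heads = n_heads_per_shard // n_kv_heads  # 4   (should be 8)
old_key_value_dim = dim // n_kv_heads                 # 512 (should be 1024)

# New formulas:
new_local_kv_heads = n_kv_heads // num_shards   # 8
new_key_value_dim = dims_per_head * n_kv_heads  # 1024

print(old_local_kv_heads, old_key_value_dim, new_local_kv_heads, new_key_value_dim)
```

For Llama 2 70B (dim=8192, n_heads=64, n_kv_heads=8, 8 shards) both variants give 1 local KV head and a key_value_dim of 1024, which is why the old code appeared correct until the Llama 3 configurations were added.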
@@ -72,7 +70,10 @@ def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
         for i, tensor in enumerate(tensors):
             state_dict[i][name] = tensor.clone()

-    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    concat_dim = 0 if llama_version == 3 else 1
+    insert_chunk(
+        "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim
+    )
     insert("norm.weight", loaded["model.norm.weight"])
     insert_chunk("output.weight", loaded["lm_head.weight"], 0)
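concat_dim reflects how the Meta-format checkpoints shard the embedding table: Llama 3 appears to split tok_embeddings.weight along the vocabulary dimension (dim 0, like output.weight), while Llama 2 split it along the embedding dimension (dim 1). A toy illustration of the resulting per-shard shapes (sizes are made up):

```
import torch

# Toy sizes only; the real tensor is (vocab_size, dim) from the HF checkpoint.
embed = torch.randn(16, 8)
num_shards = 2

llama3_shards = embed.chunk(num_shards, dim=0)  # concat_dim = 0: split over the vocab axis
llama2_shards = embed.chunk(num_shards, dim=1)  # concat_dim = 1: split over the embedding axis

print([tuple(t.shape) for t in llama3_shards])  # two (8, 8) pieces
print([tuple(t.shape) for t in llama2_shards])  # two (16, 4) pieces
```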

@@ -136,7 +137,12 @@ def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
             f"layers.{layer_i}.ffn_norm.weight",
             loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
         )
-    insert("rope.freqs", inv_freq)
+    if llama_version != 3:
+        base = 10000.0
+        inv_freq = (
+            1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+        ).to(dtype)
+        insert("rope.freqs", inv_freq)

     for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
         torch.save(
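rope.freqs is now written only for Llama 2 style checkpoints: the script infers the version from the vocabulary size in params.json (128256 for Llama 3) and skips the precomputed inverse frequencies otherwise, since Llama 3 checkpoints do not carry that tensor. A minimal sketch of the heuristic, assuming params.json sits in the conversion output directory used in the README:

```
import json
from pathlib import Path

# "test70B" is the README's output directory; any conversion output dir works.
params = json.loads((Path("test70B") / "params.json").read_text())

# Llama 3 checkpoints use a 128256-token vocabulary; anything else is treated as Llama 2.
llama_version = 3 if params.get("vocab_size") == 128256 else 2

# Only Llama 2 style checkpoints get the precomputed "rope.freqs" tensor.
print(f"llama_version={llama_version}, write rope.freqs: {llama_version != 3}")
```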
