Commit b4fdb38

Script to add MTP heads to existing model (#284)
1 parent 0e1c23c commit b4fdb38

File tree

1 file changed: +108 −0 lines changed

tools/transformer_add_mtp_heads.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
import enum
import logging
import pathlib

import yaml
from transformers import AutoModelForCausalLM

from fast_llm.engine.config_utils.runnable import RunnableConfig

from fast_llm.config import Config, Field, config_class  # isort:skip

logger = logging.getLogger(__name__)


class PredictionHeadInitMethod(str, enum.Enum):
    from_existing = "from_existing"
    # from_scratch = "from_scratch"


@config_class()
class AddPredictionHeadsConfig(RunnableConfig):
    hf_checkpoint: pathlib.Path = Field()
    output_dir: pathlib.Path = Field()
    num_prediction_heads: int = Field()
    prediction_head_init_method: PredictionHeadInitMethod = Field()
    prediction_head_init_std: float = Field(default=0.0)
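    # prediction_head_init_std is not used by the from_existing path below;
    # presumably reserved for the commented-out from_scratch option.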

    def _validate(self):
        super()._validate()
        assert self.prediction_head_init_method == PredictionHeadInitMethod.from_existing

    def run(self):
        logger.info(f"Loading {self.hf_checkpoint}...")
        model = AutoModelForCausalLM.from_pretrained(self.hf_checkpoint)
        assert model.config.model_type in [
            "llama",
            "mistral",
            "apriel",
        ], f"Model type {model.config.model_type} is not supported"
        # We convert the models to MTP-Llama. It does not support sliding window.
        if model.config.model_type == "mistral":
            assert model.config.sliding_window is None
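            # MistralConfig has no mlp_bias attribute, so set it explicitly for
            # the Llama-style target config.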
            model.config.mlp_bias = False
        state_dict = model.state_dict()

        logger.info(f"Adding Prediction Heads to {self.hf_checkpoint}...")
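
        # Each new head is a full transformer layer initialized as a clone of the
        # base model's last layer; the existing last layer serves as the first
        # head, so only num_prediction_heads - 1 extra layers are created.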
        # New prediction-heads' transformer layers
        hf_mtp_head_base_name = "model.mtp_heads"
        # Last layer is the first head
        last_layer_base_name = f"model.layers.{model.config.num_hidden_layers - 1}"
        for i in range(self.num_prediction_heads - 1):
            for w in ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj"]:
                state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.weight"] = state_dict[
                    f"{last_layer_base_name}.{w}.weight"
                ].clone()
                # Llama: no bias in attention
                assert f"{last_layer_base_name}.{w}.bias" not in state_dict, "Bias found in state_dict"
            for w in ["input_layernorm", "post_attention_layernorm"]:
                # RMS norm: no bias
                state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.weight"] = state_dict[
                    f"{last_layer_base_name}.{w}.weight"
                ].clone()
                assert f"{last_layer_base_name}.{w}.bias" not in state_dict, "Bias found in state_dict"
            for w in ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]:
                state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.weight"] = state_dict[
                    f"{last_layer_base_name}.{w}.weight"
                ].clone()
                if model.config.mlp_bias:
                    state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.bias"] = state_dict[
                        f"{last_layer_base_name}.{w}.bias"
                    ].clone()
                else:
                    assert f"{last_layer_base_name}.{w}.bias" not in state_dict, "Bias found in state_dict"

        # Layer norms: the base model's final norm is moved (popped, not copied)
        # to become the first head's norm; every additional head gets its own clone.
        hf_mtp_norm_base_name = "model.mtp_norms"
        state_dict[f"{hf_mtp_norm_base_name}.0.weight"] = state_dict.pop("model.norm.weight")
        assert "model.norm.bias" not in state_dict, "Bias found in state_dict"
        for i in range(1, self.num_prediction_heads):
            state_dict[f"{hf_mtp_norm_base_name}.{i}.weight"] = state_dict[f"{hf_mtp_norm_base_name}.0.weight"].clone()

        # Adjust config
        model.config.prediction_heads = self.num_prediction_heads
        model.config.auto_map = {
            "AutoConfig": "configuration_mtp_llama.MTPLlamaConfig",
            "AutoModel": "modeling_mtp_llama.MTPLlamaModel",
            "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM",
        }
        # model.config.architectures = ["MTPLlamaForCausalLM"]
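        # (save_pretrained typically rewrites "architectures" from the live model
        # class, which is presumably why the line above is disabled and manual
        # config.json edits are flagged below.)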

        # Save model
        logger.info(f"Saving model to {self.output_dir}...")
        model.save_pretrained(self.output_dir, state_dict=state_dict)
        logger.warning(
            f"WARNING: Model architecture needs to be updated manually to MTPLlamaForCausalLM in {self.output_dir}/config.json"
        )
        logger.warning(
            f"WARNING: Model-type needs to be updated manually to mtp_llama in {self.output_dir}/config.json"
        )

        # Save surgery config as yaml
        yaml.safe_dump(self.to_serialized(), (self.output_dir / "surgery_config.yaml").open("w"))
        logger.info("Done!")


if __name__ == "__main__":
    AddPredictionHeadsConfig.parse_and_run()
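
For reference, a minimal sketch of driving the script programmatically rather than through the CLI entry point. The field names come from AddPredictionHeadsConfig above; the from_dict and validate calls are assumptions about Fast-LLM's Config API, and the checkpoint paths are placeholders:

    # Hypothetical usage sketch, not part of this commit.
    import pathlib

    from tools.transformer_add_mtp_heads import AddPredictionHeadsConfig

    config = AddPredictionHeadsConfig.from_dict(
        {
            "hf_checkpoint": pathlib.Path("/path/to/llama-checkpoint"),
            "output_dir": pathlib.Path("/path/to/llama-mtp"),
            "num_prediction_heads": 4,
            "prediction_head_init_method": "from_existing",
        }
    )
    config.validate()  # runs _validate above (assumed Config API)
    config.run()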
