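"""Add multi-token prediction (MTP) heads to an existing Hugging Face causal LM checkpoint.

The new heads are initialized as copies of the source model's last transformer layer and
final norm, and the converted model is saved in MTP-Llama format.
"""
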
import enum
import logging
import pathlib

import yaml
from transformers import AutoModelForCausalLM

from fast_llm.engine.config_utils.runnable import RunnableConfig

from fast_llm.config import Config, Field, config_class  # isort:skip

logger = logging.getLogger(__name__)


class PredictionHeadInitMethod(str, enum.Enum):
    # Initialize the new heads as copies of the existing last transformer layer.
    from_existing = "from_existing"
    # from_scratch = "from_scratch"


@config_class()
class AddPredictionHeadsConfig(RunnableConfig):
    # Path to the source Hugging Face checkpoint.
    hf_checkpoint: pathlib.Path = Field()
    # Directory where the converted checkpoint is written.
    output_dir: pathlib.Path = Field()
    # Total number of prediction heads, counting the model's existing head as the first one.
    num_prediction_heads: int = Field()
    prediction_head_init_method: PredictionHeadInitMethod = Field()
    # Not used by the `from_existing` method.
    prediction_head_init_std: float = Field(default=0.0)

    def _validate(self):
        super()._validate()
        # Only initialization from the existing last layer is supported for now.
        assert self.prediction_head_init_method == PredictionHeadInitMethod.from_existing

    def run(self):
        logger.info(f"Loading {self.hf_checkpoint}...")
        model = AutoModelForCausalLM.from_pretrained(self.hf_checkpoint)
        assert model.config.model_type in [
            "llama",
            "mistral",
            "apriel",
        ], f"Model type {model.config.model_type} is not supported"
        # We convert the model to MTP-Llama, which does not support sliding-window attention.
        if model.config.model_type == "mistral":
            assert model.config.sliding_window is None
            model.config.mlp_bias = False
        state_dict = model.state_dict()

        logger.info(f"Adding prediction heads to {self.hf_checkpoint}...")

        # Transformer layers of the new prediction heads.
        hf_mtp_head_base_name = "model.mtp_heads"
        # The last layer of the base model acts as the first head, so only
        # `num_prediction_heads - 1` extra head layers are added, each copied from that layer.
        last_layer_base_name = f"model.layers.{model.config.num_hidden_layers - 1}"
        for i in range(self.num_prediction_heads - 1):
            for w in ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj"]:
                state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.weight"] = state_dict[
                    f"{last_layer_base_name}.{w}.weight"
                ].clone()
                # Llama: no bias in attention
                assert f"{last_layer_base_name}.{w}.bias" not in state_dict, "Bias found in state_dict"
            for w in ["input_layernorm", "post_attention_layernorm"]:
                # RMS norm: no bias
                state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.weight"] = state_dict[
                    f"{last_layer_base_name}.{w}.weight"
                ].clone()
                assert f"{last_layer_base_name}.{w}.bias" not in state_dict, "Bias found in state_dict"
            for w in ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]:
                state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.weight"] = state_dict[
                    f"{last_layer_base_name}.{w}.weight"
                ].clone()
                if model.config.mlp_bias:
                    state_dict[f"{hf_mtp_head_base_name}.{i}.{w}.bias"] = state_dict[
                        f"{last_layer_base_name}.{w}.bias"
                    ].clone()
                else:
                    assert f"{last_layer_base_name}.{w}.bias" not in state_dict, "Bias found in state_dict"

        # Final layer norms: the original output norm becomes the first head's norm,
        # and the remaining heads get copies of it.
        hf_mtp_norm_base_name = "model.mtp_norms"
        state_dict[f"{hf_mtp_norm_base_name}.0.weight"] = state_dict.pop("model.norm.weight")
        assert "model.norm.bias" not in state_dict, "Bias found in state_dict"
        for i in range(1, self.num_prediction_heads):
            state_dict[f"{hf_mtp_norm_base_name}.{i}.weight"] = state_dict[f"{hf_mtp_norm_base_name}.0.weight"].clone()

        # Adjust the config: point the transformers Auto classes at the custom MTP-Llama implementation.
        model.config.prediction_heads = self.num_prediction_heads
        model.config.auto_map = {
            "AutoConfig": "configuration_mtp_llama.MTPLlamaConfig",
            "AutoModel": "modeling_mtp_llama.MTPLlamaModel",
            "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM",
        }
        # model.config.architectures = ["MTPLlamaForCausalLM"]

        # Save model
        logger.info(f"Saving model to {self.output_dir}...")
        model.save_pretrained(self.output_dir, state_dict=state_dict)
        logger.warning(
            f"Model architecture needs to be updated manually to MTPLlamaForCausalLM in {self.output_dir}/config.json"
        )
        logger.warning(
            f"Model type needs to be updated manually to mtp_llama in {self.output_dir}/config.json"
        )
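        # For reference, after those manual edits the saved config.json is expected to
        # contain entries along these lines (illustrative; this script does not write them):
        #   "architectures": ["MTPLlamaForCausalLM"],
        #   "model_type": "mtp_llama",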

        # Save the surgery config as YAML alongside the converted checkpoint.
        with (self.output_dir / "surgery_config.yaml").open("w") as f:
            yaml.safe_dump(self.to_serialized(), f)
        logger.info("Done!")


if __name__ == "__main__":
    AddPredictionHeadsConfig.parse_and_run()
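
# Illustrative example of the surgery parameters (field names match AddPredictionHeadsConfig
# above; values are made up, and how they are supplied depends on the command-line interface
# of RunnableConfig.parse_and_run):
#
#   hf_checkpoint: /path/to/source-checkpoint
#   output_dir: /path/to/mtp-checkpoint
#   num_prediction_heads: 4
#   prediction_head_init_method: from_existing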