diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh
index 8143f9ea9a4..65d04dc3886 100755
--- a/.ci/scripts/test_model.sh
+++ b/.ci/scripts/test_model.sh
@@ -100,6 +100,14 @@ test_model() {
     rm "./${MODEL_NAME}.pte"
     return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears.
   fi
+  if [[ "${MODEL_NAME}" == "phi4_mini" ]]; then
+    # Install requirements for export_llama
+    bash examples/models/llama/install_requirements.sh
+    # Test export_llama script: python3 -m examples.models.llama.export_llama.
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi-4-mini/config.json
+    run_portable_executor_runner
+    rm "./${MODEL_NAME}.pte"
+  fi

   # Export a basic .pte and run the model.
   "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"
diff --git a/examples/models/__init__.py b/examples/models/__init__.py
index 55f5c449ca2..778138a1802 100644
--- a/examples/models/__init__.py
+++ b/examples/models/__init__.py
@@ -35,6 +35,7 @@
     "llava": ("llava", "LlavaModel"),
     "efficient_sam": ("efficient_sam", "EfficientSAM"),
     "qwen2_5": ("qwen2_5", "Qwen2_5Model"),
+    "phi4_mini": ("phi4_mini", "Phi4MiniModel"),
 }

 __all__ = [
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 3a1f423aa27..6e5eca11d0e 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -93,6 +93,7 @@
     "llama3_2",
     "static_llama",
     "qwen2_5",
+    "phi4_mini",
 ]

 TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 226e0049803..714976e34fe 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -38,6 +38,7 @@ class ModelArgs:
     apply_embedding: bool = True  # Use embedding inside the transformer
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
+    partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
     )
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index e081c442032..02eb564ed76 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -134,11 +134,21 @@ def forward(


 # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
-def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
+# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
+# Currently only supports non-long RoPE.
+def hf_precompute_freqs_cis(
+    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+):
+    # Partial rotary embeddings.
+    dim = int(dim * partial_rotary_factor)
+
+    # Short factor scaling.
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
     )
+    # TODO: support long factor scaling.
+    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
     t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
         freqs  # pyre-ignore
     )
@@ -180,8 +190,13 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    rotary_dim = cos.shape[-1]
+    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
+    k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
     return q_embed, k_embed


@@ -217,7 +232,10 @@ def __init__(self, params: ModelArgs):

         # Choose the appropriate RoPE implementation
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
             self.precompute_freqs_cis = partial(
diff --git a/examples/models/phi-4-mini/config.json b/examples/models/phi-4-mini/config.json
new file mode 100644
index 00000000000..edce93e59fa
--- /dev/null
+++ b/examples/models/phi-4-mini/config.json
@@ -0,0 +1,15 @@
+{
+  "dim": 3072,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 8192,
+  "n_heads": 24,
+  "n_kv_heads": 8,
+  "n_layers": 32,
+  "norm_eps": 1e-05,
+  "rope_theta": 10000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 200064,
+  "use_hf_rope": true,
+  "partial_rotary_factor": 0.75,
+  "attention_qkv_bias": false
+}
diff --git a/examples/models/phi-4-mini/convert_weights.py b/examples/models/phi-4-mini/convert_weights.py
new file mode 100644
index 00000000000..c29231d2e4d
--- /dev/null
+++ b/examples/models/phi-4-mini/convert_weights.py
@@ -0,0 +1,88 @@
+import argparse
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+from torchtune.training import FullModelHFCheckpointer
+
+
+# Standard _FROM_META weight mapping of Meta weights to TorchTune.
+_PHI_4_FROM_META = {
+    "tok_embeddings.weight": "tok_embeddings.weight",
+    "norm.weight": "norm.scale",
+    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
+    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
+    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
+    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
+    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
+    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
+    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
+    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
+    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+}
+
+
+def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+ """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _PHI_4_FROM_META.items()} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + # Input and output embeddings are tied. + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Phi-4-mini weights to Meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + + checkpointer = FullModelHFCheckpointer( + checkpoint_dir=args.input_dir, + checkpoint_files=[ + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + ], + output_dir=".", + model_type="PHI3_MINI", + ) + + print("Loading checkpoint...") + sd = checkpointer.load_checkpoint() + + print("Converting checkpoint...") + sd = phi_4_tune_to_meta(sd["model"]) + + torch.save(sd, args.output) + print(f"Checkpoint saved to {args.output}") + + +if __name__ == "__main__": + main()