Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions examples/models/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

# pyre-unsafe

import json
import os
from pathlib import Path
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -74,3 +76,39 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
)
return dtype


def load_checkpoint_from_pytorch_model(input_dir: str) -> Dict:
index_path = os.path.join(input_dir, "pytorch_model.bin.index.json")
if os.path.exists(index_path):
# Sharded checkpoint.
with open(index_path, "r") as f:
index = json.load(f)
weight_map = index["weight_map"]
checkpoint_shards = sorted(set(weight_map.values()))

# Load all the shards into memory
shard_to_weights = {}
for shard in checkpoint_shards:
shard_to_weights[shard] = torch.load(
os.path.join(input_dir, shard),
weights_only=True,
map_location=torch.device("cpu"),
)

# Merge tensors into consolidated state dict.
merged_state_dict = {}
for weight_name, shard in weight_map.items():
tensor = shard_to_weights[shard][weight_name]
merged_state_dict[weight_name] = tensor
return merged_state_dict

# Single checkpoint
model_path = os.path.join(input_dir, "pytorch_model.bin")
if os.path.exists(model_path):
state_dict = torch.load(
model_path, weights_only=True, map_location=torch.device("cpu")
)
return state_dict

raise FileNotFoundError(f"Could not find pytorch_model checkpoint in {input_dir}")
18 changes: 6 additions & 12 deletions examples/models/phi_4_mini/convert_weights.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
import os
from typing import Dict

import torch
from executorch.examples.models.checkpoint import load_checkpoint_from_pytorch_model

from torchtune.models.convert_weights import get_mapped_key

Expand Down Expand Up @@ -87,10 +87,8 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
Convert a state dict from torchtune's format to Meta's format. This function
doesn't handle any sharding or splitting of state dicts. It follows the
state_dict IN -> state_dict OUT pattern.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
Expand All @@ -105,14 +103,15 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
# If input_dir_or_checkpoint is a directory downloaded from HF, FullModelHFCheckpointer is used to extract the state dict
# If input_dir_or_checkpoint is a checkpoint (from eager model model), it is loaded directly
if os.path.isdir(input_dir_or_checkpoint):
try:
sd = load_checkpoint_from_pytorch_model(input_dir_or_checkpoint)
print("Converting checkpoint...")
sd = phi_4_hf_to_meta(sd)
except FileNotFoundError:
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=input_dir_or_checkpoint,
checkpoint_files=[
Expand All @@ -127,11 +126,6 @@ def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
sd = sd["model"]
print("Converting checkpoint...")
sd = phi_4_tune_to_meta(sd)
else:
print("Loading checkpoint from file...")
sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
print("Converting checkpoint...")
sd = phi_4_hf_to_meta(sd)

print("Saving checkpoint...")
torch.save(sd, output_file)
Expand Down
36 changes: 26 additions & 10 deletions examples/models/qwen3/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict

import torch
from executorch.examples.models.checkpoint import load_checkpoint_from_pytorch_model
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key
Expand Down Expand Up @@ -80,19 +81,34 @@ def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
tensor = shard_to_weights[shard][weight_name]
merged_state_dict[weight_name] = tensor
return merged_state_dict
else:
# Single checkpoint.
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
return state_dict

# Single checkpoint.
model_path = os.path.join(input_dir, "model.safetensors")
if os.path.exists(model_path):
return load_file(os.path.join(input_dir, "model.safetensors"))

raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}")


def load_checkpoint(input_dir: str) -> Dict:
pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
if os.path.exists(pytorch_path):
print("Loading checkpoint from PyTorch .bin file")
return torch.load(pytorch_path, map_location="cpu", weights_only=True)
print("Loading checkpoint from safetensors directory")
return load_checkpoint_from_safetensors(input_dir)
try:
print("Loading checkpoint from pytorch_model directory")
state_dict = load_checkpoint_from_pytorch_model(input_dir)
return state_dict
except FileNotFoundError:
print(
"Could not find pytorch_model checkpoints in directory, trying safetensors"
)
pass

try:
print("Loading checkpoint from safetensors directory")
state_dict = load_checkpoint_from_safetensors(input_dir)
return state_dict
except FileNotFoundError:
pass

raise FileNotFoundError(f"Could not find checkpoint in {input_dir}")


def convert_weights(input_dir: str, output_file: str) -> None:
Expand Down
Loading