Merged
Changes from 3 commits
53 changes: 42 additions & 11 deletions examples/models/phi_4_mini/convert_weights.py
@@ -1,4 +1,5 @@
import argparse
import json
import os
from typing import Dict

@@ -87,10 +88,8 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
@@ -105,14 +104,51 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict
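As the docstring notes, the converter is a pure state_dict IN -> state_dict OUT mapping, and the assignment above ties the output projection to the token embedding. A minimal usage sketch (the checkpoint file name is hypothetical):

    # Hypothetical torchtune-format Phi-4-mini checkpoint.
    tune_sd = torch.load("phi4_tune.pt", weights_only=True, map_location="cpu")
    meta_sd = phi_4_tune_to_meta(tune_sd)
    # Both keys reference the same tensor object (tied embeddings).
    assert meta_sd["output.weight"] is meta_sd["tok_embeddings.weight"]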


def load_checkpoint_from_pytorch_model(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "pytorch_model.bin.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
        with open(index_path, "r") as f:
            index = json.load(f)
        weight_map = index["weight_map"]
        checkpoint_shards = sorted(set(weight_map.values()))

        # Load all the shards into memory.
        shard_to_weights = {}
        for shard in checkpoint_shards:
            shard_to_weights[shard] = torch.load(
                os.path.join(input_dir, shard),
                weights_only=True,
                map_location=torch.device("cpu"),
            )

        # Merge tensors into consolidated state dict.
        merged_state_dict = {}
        for weight_name, shard in weight_map.items():
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict

    # Single checkpoint.
    model_path = os.path.join(input_dir, "pytorch_model.bin")
    if os.path.exists(model_path):
        state_dict = torch.load(
            model_path, weights_only=True, map_location=torch.device("cpu")
        )
        return state_dict

    raise FileNotFoundError(f"Could not find pytorch_model checkpoint in {input_dir}")
Contributor: should you make this a utility function to deduplicate logic?
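One possible shape for that shared helper, hoisted into a common module so phi_4_mini and qwen3 stop carrying identical copies (module path is hypothetical; the body mirrors the two copies in this PR):

    # examples/models/checkpoint_utils.py (hypothetical location)
    import json
    import os
    from typing import Dict

    import torch


    def load_checkpoint_from_pytorch_model(input_dir: str) -> Dict:
        """Load a pytorch_model checkpoint (sharded or single-file) onto CPU."""
        index_path = os.path.join(input_dir, "pytorch_model.bin.index.json")
        if os.path.exists(index_path):
            # Sharded checkpoint: load each shard once, then merge per weight.
            with open(index_path, "r") as f:
                weight_map = json.load(f)["weight_map"]
            shard_to_weights = {
                shard: torch.load(
                    os.path.join(input_dir, shard),
                    weights_only=True,
                    map_location=torch.device("cpu"),
                )
                for shard in sorted(set(weight_map.values()))
            }
            return {
                name: shard_to_weights[shard][name]
                for name, shard in weight_map.items()
            }

        # Single checkpoint.
        model_path = os.path.join(input_dir, "pytorch_model.bin")
        if os.path.exists(model_path):
            return torch.load(
                model_path, weights_only=True, map_location=torch.device("cpu")
            )

        raise FileNotFoundError(f"Could not find pytorch_model checkpoint in {input_dir}")

Both convert_weights.py files could then import it instead of defining their own copies.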


def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
    # If input_dir_or_checkpoint is a directory downloaded from HF, FullModelHFCheckpointer is used to extract the state dict.
    # If input_dir_or_checkpoint is a checkpoint (from an eager model), it is loaded directly.
    if os.path.isdir(input_dir_or_checkpoint):
        try:
            sd = load_checkpoint_from_pytorch_model(input_dir_or_checkpoint)
            print("Converting checkpoint...")
            sd = phi_4_hf_to_meta(sd)
        except FileNotFoundError:
            checkpointer = FullModelHFCheckpointer(
                checkpoint_dir=input_dir_or_checkpoint,
                checkpoint_files=[
@@ -127,11 +163,6 @@ def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
sd = sd["model"]
print("Converting checkpoint...")
sd = phi_4_tune_to_meta(sd)
else:
print("Loading checkpoint from file...")
sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
print("Converting checkpoint...")
sd = phi_4_hf_to_meta(sd)

print("Saving checkpoint...")
torch.save(sd, output_file)
69 changes: 60 additions & 9 deletions examples/models/qwen3/convert_weights.py
@@ -80,19 +80,70 @@ def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict
    else:
        # Single checkpoint.
        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))

    # Single checkpoint.
    model_path = os.path.join(input_dir, "model.safetensors")
    if os.path.exists(model_path):
        return load_file(os.path.join(input_dir, "model.safetensors"))

    raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}")


def load_checkpoint_from_pytorch_model(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "pytorch_model.bin.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
        with open(index_path, "r") as f:
            index = json.load(f)
        weight_map = index["weight_map"]
        checkpoint_shards = sorted(set(weight_map.values()))

        # Load all the shards into memory.
        shard_to_weights = {}
        for shard in checkpoint_shards:
            shard_to_weights[shard] = torch.load(
                os.path.join(input_dir, shard),
                weights_only=True,
                map_location=torch.device("cpu"),
            )

        # Merge tensors into consolidated state dict.
        merged_state_dict = {}
        for weight_name, shard in weight_map.items():
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict

    # Single checkpoint.
    model_path = os.path.join(input_dir, "pytorch_model.bin")
    if os.path.exists(model_path):
        state_dict = torch.load(
            model_path, weights_only=True, map_location=torch.device("cpu")
        )
        return state_dict

    raise FileNotFoundError(f"Could not find pytorch_model checkpoint in {input_dir}")
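For reference, the sharded branch above reads the standard Hugging Face shard index. An illustrative sketch of its shape (shard names and size are made up):

    # Illustrative contents of pytorch_model.bin.index.json:
    index = {
        "metadata": {"total_size": 5559367680},
        "weight_map": {
            "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
            "lm_head.weight": "pytorch_model-00002-of-00002.bin",
        },
    }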


def load_checkpoint(input_dir: str) -> Dict:
    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
    if os.path.exists(pytorch_path):
        print("Loading checkpoint from PyTorch .bin file")
        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
    print("Loading checkpoint from safetensors directory")
    return load_checkpoint_from_safetensors(input_dir)
    try:
        print("Loading checkpoint from pytorch_model directory")
        state_dict = load_checkpoint_from_pytorch_model(input_dir)
        return state_dict
    except FileNotFoundError:
        print(
            "Could not find pytorch_model checkpoints in directory, trying safetensors"
        )
        pass

    try:
        print("Loading checkpoint from safetensors directory")
        state_dict = load_checkpoint_from_safetensors(input_dir)
        return state_dict
    except FileNotFoundError:
        pass

    raise FileNotFoundError(f"Could not find checkpoint in {input_dir}")
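A quick sketch of the resulting fallback behavior (directory path is hypothetical; assume it contains only model.safetensors):

    sd = load_checkpoint("/tmp/qwen3-checkpoint")
    # Prints:
    #   Loading checkpoint from pytorch_model directory
    #   Could not find pytorch_model checkpoints in directory, trying safetensors
    #   Loading checkpoint from safetensors directory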


def convert_weights(input_dir: str, output_file: str) -> None: