378 changes: 378 additions & 0 deletions skyrl-train/scripts/convert_fsdp_to_hf.py
@@ -0,0 +1,378 @@
"""
Helper script to convert FSDP checkpoint shards into safetensors model files compatible with the Hugging Face API.

The main purpose is to let users who chose not to enable HF model saves during training (e.g., via the
`hf_save_interval` parameter) still produce an HF safetensors model from an existing checkpoint.

For FSDP2 model shards, the output directory will be created with the following structure:
.
├── added_tokens.json
├── chat_template.jinja (optional: chat template used for chat-specific tasks)
├── config.json
├── generation_config.json (optional: default decoding parameters)
├── merges.txt
├── model.safetensors
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.json

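The input checkpoint directory (--ckpt-dir) is expected to look roughly like the following
(inferred from how this script locates files; exact shard names may vary):
.
├── policy
│   ├── huggingface/                        (config.json, tokenizer files, ...)
│   └── model_world_size_<N>_rank_<r>.pt    (one FSDP shard per rank)
└── trainer_state.pt
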
Example usage:
uv run --isolated --frozen --extra vllm scripts/convert_fsdp_to_hf.py --ckpt-dir /home/ray/ckpts/gsm8k_0.5B_ckpt/global_step_10 --out-dir /home/ray/hf/glob_step_10
"""

import argparse
import re
import sys
from pathlib import Path
import shutil
from typing import Dict, List, Optional

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModel
import torch


def find_policy_dir(chkpt_dir: Path) -> Path:
"""
Return the path to the 'policy/' directory under the checkpoint directory.
"""
pol = chkpt_dir / "policy"
if not pol.exists():
print(f"[error] Expected 'policy/' under {chkpt_dir}")
Reviewer comment (medium): The script uses print for status and error messages. Using the standard logging module
is a better practice for scripts like this. It provides more flexibility, such as controlling log levels
(INFO, WARN, ERROR), formatting, and directing output to different destinations (e.g., file, stderr) without
changing the code.

raise FileNotFoundError(f"Expected 'policy/' under {chkpt_dir}")
return pol
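
Acting on the reviewer's logging suggestion could look roughly like the sketch below; the logger name and format
are arbitrary choices, not part of the original script:

import logging

logger = logging.getLogger("convert_fsdp_to_hf")
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

# e.g. replacing the print calls in this script:
#   logger.error("Expected 'policy/' under %s", chkpt_dir)
#   logger.info("Found %d model shard(s).", len(shards_paths))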


def get_model_shards(policy_dir: Path) -> List[Path]:
"""
Return the sorted list of model shard paths (the model*.pt files) under the policy directory.
"""
shards_paths = sorted(policy_dir.glob("model_world_size_*_rank_*.pt"))
if not shards_paths:
shards_paths = sorted(policy_dir.glob("model*.pt"))
if not shards_paths:
print(f"[error] No model shards found under {policy_dir}")
raise FileNotFoundError(f"No model shards found under {policy_dir}")
return shards_paths


# Not used at the moment
def normalize_key(k: str) -> str:
"""
Return a normalized key to ensure consistency across checkpointing frameworks
Example - Attention layer training:
"module.encoder.layer.0.attention.self.query.weight"
"model.module.encoder.layer.0.attention.self.query.weight"
"encoder.layer.0.attention.self.query.weight"
These 3 should refer to the same thing, so they should be normalized.

The function takes a key string and strips these wrapper prefixes.
"""
k = re.sub(r"^(module|model)\.", "", k)
k = k.replace("_fsdp_wrapped_module.", "")
return k
Reviewer comment on lines +61 to +75 (medium): This function normalize_key is defined but never used in the script.
Unused code can be confusing and adds maintenance overhead. It should either be used to normalize the state
dictionary keys in merge_shards (by setting nk = normalize_key(k)) or removed if it's not necessary.


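If the function is kept, a minimal sketch of the reviewer's first option, i.e. applying it inside merge_shards
(only the nk assignment changes relative to the loop further below):

# Inside merge_shards(): normalize each key before merging so that
# "module.encoder...", "model.module.encoder..." and "encoder..."
# all collapse into a single entry.
for k, v in sd.items():
    nk = normalize_key(k)  # instead of nk = k
    if nk in merged:
        merged[nk] = merge_two_shards(merged[nk], v.detach().cpu().contiguous(), key=nk, merge_type="default")
    else:
        merged[nk] = v.detach().cpu().contiguous()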

def load_single_shard(path: Path) -> Dict[str, torch.Tensor]:
"""
Load a single model shard and return a dictionary of tensors
- Dict[str, torch.Tensor]
"""
obj = torch.load(path, map_location="cpu", weights_only=False)
Reviewer comment (medium): Using torch.load with weights_only=False can be a security risk as it allows unpickling
of arbitrary Python objects, which can lead to code execution. While you are likely loading checkpoints you created
yourself, it's a best practice to use weights_only=True whenever possible to mitigate this risk. If the checkpoint
files only contain tensors and dictionaries of tensors, this change should work without issues.

Suggested change:
- obj = torch.load(path, map_location="cpu", weights_only=False)
+ obj = torch.load(path, map_location="cpu", weights_only=True)

for key in ("state_dict", "model", "module"):
if isinstance(obj, dict) and key in obj and isinstance(obj[key], dict):
obj = obj[key]
if not isinstance(obj, dict):
raise ValueError(f"Unexpected checkpoint format at {path} (type={type(obj)})")
return {k: v for k, v in obj.items() if isinstance(v, torch.Tensor)}

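One way to follow the weights_only suggestion while staying tolerant of older checkpoints is to try the safe path
first and fall back explicitly. A sketch only, assuming the shards contain plain tensors and dicts of tensors; the
helper name _load_checkpoint_object is hypothetical:

def _load_checkpoint_object(path: Path):
    # Prefer safe unpickling; fall back only if the checkpoint stores objects
    # that weights_only=True refuses to load.
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except Exception:
        print(f"[warn] weights_only=True failed for {path}; falling back to weights_only=False")
        return torch.load(path, map_location="cpu", weights_only=False)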

def is_vocab_key(k: str) -> bool:
"""
Check whether a state_dict key names a vocab-parallel tensor (embeddings / lm_head) whose shards should be
merged into a single tensor in a multi-GPU setup.
"""
suffixes = (
"embed_tokens.weight",
"lm_head.weight",
"word_embeddings.weight",
"wte.weight",
)
return any(k.endswith(sfx) for sfx in suffixes)

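For illustration, with the suffix list above (the key names here are made up):

assert is_vocab_key("model.embed_tokens.weight")                   # vocab-parallel: merged by concatenation
assert not is_vocab_key("model.layers.0.input_layernorm.weight")   # replicated: not a vocab key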

def merge_two_shards(
existing: torch.Tensor, new_shard: torch.Tensor, key: str, merge_type: Optional[str] = None
) -> torch.Tensor:
"""
Merge two tensor shards into a single tensor containing both the existing and the new shard.
Tensors whose key is a vocabulary key (see is_vocab_key) are always concatenated along dim 0.
Otherwise the merge is controlled by the `merge_type` parameter:
- (str=default) Default: concatenate on the first dimension of the tensor
- (str=vocab_parallel) Vocab-parallel (embeddings/lm_head): cat_dim = 0 (batch / words)
- (str=col_parallel) Column-parallel linear (weights split output features): cat_dim = 1 (output features)
- (str=row_parallel) Row-parallel linear (weights split input features): cat_dim = 2 (input features)
- (str=other) Replicated parameters (LayerNorm, RMSNorm, rotary, etc.): keep the existing tensor

If merge_type is None (or unrecognized), the heuristic fallback below is used.

Heuristic Fallback:
* If only dim0 differs -> cat dim=0
* If only dim1 differs -> cat dim=1
* If only dim2 differs -> cat dim=2
* If shapes equal -> add (sum) (useful for row-parallel biases)
"""
if is_vocab_key(key) or merge_type == "default":
return torch.cat([existing, new_shard], dim=0)

if merge_type == "vocab_parallel":
return torch.cat([existing, new_shard], dim=0)
elif merge_type == "col_parallel":
if existing.ndim >= 2 and new_shard.ndim >= 2:
return torch.cat([existing, new_shard], dim=1)
else:
raise ValueError(
f"existing.ndim={existing.ndim}, new_shard.ndim={new_shard.ndim}, cannot do col_parallel merging because at least 2 dimensions of both tensors are required"
)
elif merge_type == "row_parallel":
if existing.ndim >= 3 and new_shard.ndim >= 3:
return torch.cat([existing, new_shard], dim=2)
else:
raise ValueError(
f"existing.ndim={existing.ndim}, new_shard.ndim={new_shard.ndim}, cannot do row_parallel merging because at least 3 dimensions of both tensors are required"
)
elif merge_type == "other":
return existing
else:
# merge_type is None or unknown: fall back to the shape heuristic
if existing.ndim >= 2 and existing.shape[0] != new_shard.shape[0] and existing.shape[1] == new_shard.shape[1]:
# Likely word-parallel linear weight (PyTorch Linear is [out, in])
return torch.cat((existing, new_shard), dim=0)
if existing.ndim >= 2 and existing.shape[0] == new_shard.shape[0] and existing.shape[1] != new_shard.shape[1]:
# Likely col-parallel linear weight
return torch.cat((existing, new_shard), dim=1)
if (
existing.ndim >= 3
and existing.shape[0] == new_shard.shape[0]
and existing.shape[1] == new_shard.shape[1]
and existing.shape[2] != new_shard.shape[2]
):
# Likely row-parallel linear weight
return torch.cat((existing, new_shard), dim=2)
if existing.shape == new_shard.shape:
# Could be row-parallel bias or replicated tensors.
# Try SUM
return existing + new_shard
Reviewer comment on lines +166 to +169 (high): The heuristic for merging tensors of the same shape is to sum them.
This is correct for some cases like row-parallel biases, but incorrect for replicated parameters (e.g., LayerNorm
weights), which should be identical across shards. Summing replicated parameters will result in incorrect weights
(effectively scaling them by the world size), leading to silent model corruption. A safer heuristic would be to
distinguish between these cases, for example by checking for 'bias' in the tensor key.

Suggested change:
- if existing.shape == new_shard.shape:
-     # Could be row-parallel bias or replicated tensors.
-     # Try SUM
-     return existing + new_shard
+ if existing.shape == new_shard.shape:
+     # Could be row-parallel bias (needs sum) or a replicated tensor (identical).
+     # Summing replicated tensors is incorrect. A simple heuristic is to check for 'bias' in the key.
+     if "bias" in key:
+         return existing + new_shard
+     return existing  # Assume replicated

raise ValueError(f"Don't know how to merge key '{key}' with shapes {existing.shape} and {new_shard.shape}")

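For reference, this is how the function as written above (before the reviewer's suggested change) behaves on toy
tensors; the shapes and key names are made up, and the last case reproduces the replicated-parameter pitfall
flagged in the review:

a, b = torch.zeros(4, 8), torch.zeros(4, 8)

# Vocab key: always concatenated along dim 0 -> [8, 8]
assert merge_two_shards(a, b, key="model.embed_tokens.weight").shape == (8, 8)

# Explicit column-parallel merge: concatenated along dim 1 -> [4, 16]
assert merge_two_shards(a, b, key="model.layers.0.mlp.up_proj.weight", merge_type="col_parallel").shape == (4, 16)

# Heuristic fallback with equal shapes sums the shards, so a replicated
# LayerNorm weight of ones comes back doubled instead of unchanged.
ln = torch.ones(8)
assert torch.equal(merge_two_shards(ln, ln.clone(), key="model.norm.weight"), 2 * torch.ones(8))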

def merge_shards(shards_paths: List[Path]) -> Dict[str, torch.Tensor]:
"""
Merge all model shards into a single dictionary of string-based keys to their corresponding tensors
- Dict[str, torch.Tensor]
"""
merged: Dict[str, torch.Tensor] = {}
for shard in shards_paths:
sd = load_single_shard(shard)
for k, v in sd.items():
nk = k
# if nk in merged:
# if merged[nk].shape != v.shape or merged[nk].dtype != v.dtype:
# print(
# f"[error] Key collision with mismatch for '{nk}' between shards "
# f"(existing {merged[nk].shape}/{merged[nk].dtype} vs {v.shape}/{v.dtype})"
# )
# raise ValueError(
# f"Key collision with mismatch for '{nk}' between shards "
# f"(existing {merged[nk].shape}/{merged[nk].dtype} vs {v.shape}/{v.dtype})"
# )
# else:
# # Merging tensors using merge_two_shards with heuristic fallback
try:
from torch.distributed.tensor import DTensor
except (ImportError, AttributeError):
DTensor = None
try:
from torch.distributed.tensor import ShardedTensor
except (ImportError, AttributeError):
ShardedTensor = None
if DTensor is not None and isinstance(v, DTensor):
v = v.to_local()
if ShardedTensor is not None and isinstance(v, ShardedTensor):
v = v.local_tensor()

if nk in merged:
print(
f"Tensors to be merged: (original) with shape {merged[nk].shape} and (newly added) with shape {v.shape}"
)
merged[nk] = merge_two_shards(merged[nk], v.detach().cpu().contiguous(), key=nk, merge_type="default")
else:
print(f"Current tensor shape: {v.shape}")
merged[nk] = v.detach().cpu().contiguous()
if not merged:
print("[error] No tensors found in shards")
raise RuntimeError("No tensors found in shards")
return merged


def copy_hf_artifacts(policy_dir: Path, out_dir: Path) -> None:
"""
Copy Hugging Face artifacts (config, tokenizer files, chat template, etc.) from policy/huggingface/ to the output directory.
"""
hf_src = policy_dir / "huggingface"
out_dir.mkdir(parents=True, exist_ok=True)
if hf_src.exists():
for p in hf_src.iterdir():
dst = out_dir / p.name
if p.is_file():
shutil.copy2(p, dst)
elif p.is_dir():
if dst.exists():
shutil.rmtree(dst)
shutil.copytree(p, dst)
else:
print("[warn] policy/huggingface not found; you must supply a proper config/tokenizer.", file=sys.stderr)


def _materialize_for_safetensors(state_dict) -> Dict[str, torch.Tensor]:
"""
Materialize the state dict for safetensors: convert every tensor into a plain, saveable local tensor.
1) DTensor -> local tensor
2) ShardedTensor -> local tensor
Meta tensors are rejected because they carry no data and cannot be materialized.

The local tensors are then moved to CPU, detached, made contiguous, and collected into a new key -> Tensor dict.
"""
import torch
Reviewer comment (medium): The torch module is imported at the top of the file. This import torch statement inside
the _materialize_for_safetensors function is redundant and can be removed.


new_sd = {}
for k, v in state_dict.items():
if not isinstance(v, torch.Tensor):
new_sd[k] = v
continue
t = v
if type(t).__name__ == "DTensor" and hasattr(t, "to_local"):
t = t.to_local()
if type(t).__name__ == "ShardedTensor" and hasattr(t, "local_tensor"):
t = t.local_tensor()
if getattr(t, "is_meta", False):
raise RuntimeError(f"Tensor {k} is on meta device; load the real weights before saving.")
if type(t).__name__ == "FakeTensor":
raise RuntimeError(f"Tensor {k} is a FakeTensor; disable fake tensor mode for save")
if t.device.type != "cpu":
t = t.to("cpu", non_blocking=False)
t = t.detach()
if not t.is_contiguous():
t = t.contiguous()
new_sd[k] = t
return new_sd


def guess_hf_class(cfg: AutoConfig):
"""
Tries to find a reasonable HF class from config
Falls back to the generic AutoModel class if an LM head can't be detected
"""
if getattr(cfg, "is_encoder_decoder", False):
return AutoModelForSeq2SeqLM
archs = getattr(cfg, "architectures", []) or []
if any(a.endswith("ForCausalLM") for a in archs):
return AutoModelForCausalLM
decoders = {"gpt2", "gpt_bigcode", "llama", "mistral", "qwen", "qwen2", "internlm", "mpt", "phi", "falcon"}
if getattr(cfg, "model_type", "") in decoders:
return AutoModelForCausalLM
return AutoModel


def validate_load(out_dir: Path):
"""
Optional sanity check: load the saved model with Transformers to ensure the safetensors output is consumable.
Loads on CPU to avoid device/dtype quirks that could surface when loading directly on GPU.
"""
try:
cfg = AutoConfig.from_pretrained(out_dir, local_files_only=True, trust_remote_code=True)
HFClass = guess_hf_class(cfg)
_ = HFClass.from_pretrained(
out_dir, local_files_only=True, device_map=None, dtype="auto", trust_remote_code=True
)
print("[validate] HF Load OK")
except Exception as e:
print(f"[validate][error] HF Load failed: {e} ", e)
raise RuntimeError("HF Load failed")


def _untie_shared_tensors(sd) -> Dict[str, torch.Tensor]:
"""
Untie shared tensors.
Several keys may point to tensors that share the same underlying storage (same data pointer), e.g. tied embeddings.
This function returns the state dict with such shared tensors untied: every tensor after the first that reuses a
storage is cloned, so each key refers to a unique tensor.
"""
seen = {}
for k, v in list(sd.items()):
if not isinstance(v, torch.Tensor):
continue
try:
ptr = v.storage().data_ptr()
except Exception:
continue
if ptr in seen:
sd[k] = v.clone()
else:
seen[ptr] = (k, v)
return sd

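For illustration, with hypothetical tied weights (safetensors-style saving refuses tensors that share memory,
which is what this helper works around):

emb = torch.randn(10, 4)
sd = {"model.embed_tokens.weight": emb, "lm_head.weight": emb}  # two keys, one storage
assert sd["model.embed_tokens.weight"].data_ptr() == sd["lm_head.weight"].data_ptr()
sd = _untie_shared_tensors(sd)
# The second key now holds a clone with its own storage.
assert sd["model.embed_tokens.weight"].data_ptr() != sd["lm_head.weight"].data_ptr()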

def main():
ap = argparse.ArgumentParser(description="Convert FSDP checkpoint shards to a HuggingFace safetensors model.")
ap.add_argument(
"--ckpt-dir", type=str, required=True, help="Path to the checkpoint directory, containing trainer_state.pt"
)
ap.add_argument("--out-dir", type=str, required=True, help="Output for HF model folder")
ap.add_argument(
"--validate-load", action="store_true", help="Try loading with the Transformers Module after saving"
)
args = ap.parse_args()
ckpt_dir = Path(args.ckpt_dir).resolve()
output_dir = Path(args.out_dir).resolve()
policy_dir = find_policy_dir(ckpt_dir)
shards_paths = get_model_shards(policy_dir)
print(f"[info] Found {len(shards_paths)} model shard(s). ")
for s in shards_paths:
print(f"[info] - {s}")

print("[info] Merging shards...")
state_dict = merge_shards(shards_paths)
print(f"[info] Merged {len(state_dict)} tensors.")

copy_hf_artifacts(policy_dir, output_dir)

clean_sd = _materialize_for_safetensors(state_dict)
clean_sd = _untie_shared_tensors(clean_sd)

# save_file(clean_sd, str(weights_path))
# print(f"[success] Saved weights to {weights_path}")

cfg = AutoConfig.from_pretrained(output_dir, local_files_only=True, trust_remote_code=True)
HFClass = guess_hf_class(cfg)
hf_model = HFClass.from_config(cfg)
hf_model.save_pretrained(
save_directory=output_dir,
state_dict=clean_sd,
)

if args.validate_load:
validate_load(output_dir)


if __name__ == "__main__":
main()