
Conversation

Contributor

zhenga1 commented Nov 2, 2025

Wrote a script for converting FSDP2 model shards into Hugging Face .safetensors format.

Filepath:
~/skyrl-train/scripts/convert_fsdp_to_hf.py

Tested on

  • Qwen0.5B models
  • Qwen4B models

trained with run_gsm8k.sh using the fsdp2 backend.

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces a script for converting FSDP model checkpoints into the Hugging Face .safetensors format. The script is well-structured, handling various tensor types and including a validation step. However, I've identified critical flaws in the tensor merging logic that could lead to silent corruption of the model weights. The current implementation uses a hardcoded merge strategy that is not generally applicable, and the fallback heuristic contains a bug. My review includes suggestions to address these critical issues and improve the script's robustness. I've also included feedback on security best practices, code clarity, and removing unused code.

Comment on lines +166 to +169
if existing.shape == new_shard.shape:
    # Could be row-parallel bias or replicated tensors.
    # Try SUM
    return existing + new_shard

high

The heuristic for merging tensors of the same shape is to sum them. This is correct for some cases like row-parallel biases, but incorrect for replicated parameters (e.g., LayerNorm weights), which should be identical across shards. Summing replicated parameters will result in incorrect weights (effectively scaling them by the world size), leading to silent model corruption. A safer heuristic would be to distinguish between these cases, for example by checking for 'bias' in the tensor key.

Suggested change
if existing.shape == new_shard.shape:
    # Could be row-parallel bias or replicated tensors.
    # Try SUM
    return existing + new_shard
if existing.shape == new_shard.shape:
    # Could be row-parallel bias (needs sum) or a replicated tensor (identical).
    # Summing replicated tensors is incorrect. A simple heuristic is to check for 'bias' in the key.
    if "bias" in key:
        return existing + new_shard
    return existing  # Assume replicated
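
A stricter variant of this heuristic (a sketch, not part of the suggested change) would verify the replicated assumption instead of silently keeping one copy, so an unrecognized sharding pattern fails loudly:

if existing.shape == new_shard.shape:
    if "bias" in key:
        # Row-parallel bias: each rank holds a partial sum.
        return existing + new_shard
    if not torch.allclose(existing, new_shard):
        raise ValueError(f"Shards for '{key}' have equal shapes but differ; unknown sharding")
    return existing  # Verified replicated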

Comment on lines +61 to +75
# Not used at the moment
def normalize_key(k: str) -> str:
    """
    Return a normalized key to ensure consistency across checkpointing frameworks.
    Example - Attention layer training:
        "module.encoder.layer.0.attention.self.query.weight"
        "model.module.encoder.layer.0.attention.self.query.weight"
        "encoder.layer.0.attention.self.query.weight"
    These 3 should refer to the same thing, so they should be normalized.

    Function takes a string and removes all possible prefixes.
    """
    k = re.sub(r"^(module|model)\.", "", k)
    k = k.replace("_fsdp_wrapped_module.", "")
    return k

medium

This function normalize_key is defined but never used in the script. Unused code can be confusing and adds maintenance overhead. It should either be used to normalize the state dictionary keys in merge_shards (by setting nk = normalize_key(k)) or removed if it's not necessary.
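
A minimal sketch of the first option, assuming merge_shards walks each shard's state dict (the loop structure and the merge_tensor helper shown here are hypothetical, not the PR's actual code):

def merge_shards(shard_dicts):
    merged = {}
    for sd in shard_dicts:
        for k, v in sd.items():
            nk = normalize_key(k)  # strip 'module.'/'model.' and FSDP wrapper prefixes
            if nk in merged:
                merged[nk] = merge_tensor(merged[nk], v, nk)  # existing merge logic (hypothetical name)
            else:
                merged[nk] = v
    return merged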

    Load a single model shard and return a dictionary of tensors
    - Dict[str, torch.Tensor]
    """
    obj = torch.load(path, map_location="cpu", weights_only=False)

medium

Using torch.load with weights_only=False can be a security risk as it allows unpickling of arbitrary Python objects, which can lead to code execution. While you are likely loading checkpoints you created yourself, it's a best practice to use weights_only=True whenever possible to mitigate this risk. If the checkpoint files only contain tensors and dictionaries of tensors, this change should work without issues.

Suggested change
    obj = torch.load(path, map_location="cpu", weights_only=False)
    obj = torch.load(path, map_location="cpu", weights_only=True)
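
If some checkpoints do contain non-tensor objects (e.g., optimizer or RNG state), weights_only=True will raise; a guarded fallback such as this sketch (not part of the PR) keeps the safer default while making the trust decision explicit:

try:
    obj = torch.load(path, map_location="cpu", weights_only=True)
except Exception:
    # The file contains pickled Python objects; only fall back for checkpoints you trust.
    print(f"[warn] {path} not loadable with weights_only=True; falling back to unsafe load")
    obj = torch.load(path, map_location="cpu", weights_only=False)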

"""
pol = chkpt_dir / "policy"
if not pol.exists():
print(f"[error] Expected 'policy/' under {chkpt_dir}")

medium

The script uses print for status and error messages. Using the standard logging module is a better practice for scripts like this. It provides more flexibility, such as controlling log levels (INFO, WARN, ERROR), formatting, and directing output to different destinations (e.g., file, stderr) without changing the code.
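
A minimal sketch of that setup, assuming a module-level logger replaces the print calls (names here are illustrative):

import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)

# e.g., instead of print(f"[error] Expected 'policy/' under {chkpt_dir}"):
logger.error("Expected 'policy/' under %s", chkpt_dir)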


    Then after that, convert these local tensors to CPU tensors, and create a new dictionary of keys -> Tensors.
    """
    import torch

medium

The torch module is imported at the top of the file. This import torch statement inside the _materialize_for_safetensors function is redundant and can be removed.
