Added conversion pipeline of FSDP2 model checkpoint shards to Huggingface .safetensors #618
base: main
Conversation
Code Review
This pull request introduces a script for converting FSDP model checkpoints into the Hugging Face .safetensors format. The script is well-structured, handling various tensor types and including a validation step. However, I've identified critical flaws in the tensor merging logic that could lead to silent corruption of the model weights. The current implementation uses a hardcoded merge strategy that is not generally applicable, and the fallback heuristic contains a bug. My review includes suggestions to address these critical issues and improve the script's robustness. I've also included feedback on security best practices, code clarity, and removing unused code.
```python
if existing.shape == new_shard.shape:
    # Could be row-parallel bias or replicated tensors.
    # Try SUM
    return existing + new_shard
```
The heuristic for merging tensors of the same shape is to sum them. This is correct for some cases like row-parallel biases, but incorrect for replicated parameters (e.g., LayerNorm weights), which should be identical across shards. Summing replicated parameters will result in incorrect weights (effectively scaling them by the world size), leading to silent model corruption. A safer heuristic would be to distinguish between these cases, for example by checking for 'bias' in the tensor key.
Suggested change:

```diff
-if existing.shape == new_shard.shape:
-    # Could be row-parallel bias or replicated tensors.
-    # Try SUM
-    return existing + new_shard
+if existing.shape == new_shard.shape:
+    # Could be row-parallel bias (needs sum) or a replicated tensor (identical).
+    # Summing replicated tensors is incorrect. A simple heuristic is to check for 'bias' in the key.
+    if "bias" in key:
+        return existing + new_shard
+    return existing  # Assume replicated
```
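A stricter variant is also worth considering. The sketch below (not part of this PR; the function name `merge_same_shape` is hypothetical) verifies that same-shape shards presumed replicated are actually numerically identical before discarding one, so a mis-classified row-parallel weight fails loudly instead of silently corrupting the model:

```python
import torch

def merge_same_shape(key: str, existing: torch.Tensor, new_shard: torch.Tensor) -> torch.Tensor:
    """Sketch: merge two same-shape shards, failing loudly on ambiguity."""
    if "bias" in key:
        # Row-parallel biases are partial sums across shards.
        return existing + new_shard
    # Presumed replicated: the copies must match; if they differ, the
    # 'replicated' assumption was wrong and summing or keeping one copy
    # would silently corrupt the weights.
    if not torch.allclose(existing, new_shard):
        raise ValueError(
            f"Shards for {key!r} have equal shape but differ; "
            "cannot determine a merge strategy"
        )
    return existing
```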
```python
import re

# Not used at the moment
def normalize_key(k: str) -> str:
    """
    Return a normalized key to ensure consistency across checkpointing frameworks.

    Example - attention layer training:
        "module.encoder.layer.0.attention.self.query.weight"
        "model.module.encoder.layer.0.attention.self.query.weight"
        "encoder.layer.0.attention.self.query.weight"
    These three refer to the same parameter, so they should all normalize
    to the same key.

    Takes a key string and removes all known prefixes.
    """
    # Strip leading "module." / "model." prefixes, including stacked
    # ones such as "model.module.".
    k = re.sub(r"^((?:module|model)\.)+", "", k)
    k = k.replace("_fsdp_wrapped_module.", "")
    return k
```
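For reference, a quick check using the docstring's own example keys confirms that all three collapse to the same normalized form:

```python
keys = [
    "module.encoder.layer.0.attention.self.query.weight",
    "model.module.encoder.layer.0.attention.self.query.weight",
    "encoder.layer.0.attention.self.query.weight",
]
normalized = {normalize_key(k) for k in keys}
assert normalized == {"encoder.layer.0.attention.self.query.weight"}
```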
```python
    Load a single model shard and return a dictionary of tensors
    - Dict[str, torch.Tensor]
    """
    obj = torch.load(path, map_location="cpu", weights_only=False)
```
Using torch.load with weights_only=False can be a security risk as it allows unpickling of arbitrary Python objects, which can lead to code execution. While you are likely loading checkpoints you created yourself, it's a best practice to use weights_only=True whenever possible to mitigate this risk. If the checkpoint files only contain tensors and dictionaries of tensors, this change should work without issues.
Suggested change:

```diff
-obj = torch.load(path, map_location="cpu", weights_only=False)
+obj = torch.load(path, map_location="cpu", weights_only=True)
```
| """ | ||
| pol = chkpt_dir / "policy" | ||
| if not pol.exists(): | ||
| print(f"[error] Expected 'policy/' under {chkpt_dir}") |
The script uses print for status and error messages. Using the standard logging module is a better practice for scripts like this. It provides more flexibility, such as controlling log levels (INFO, WARN, ERROR), formatting, and directing output to different destinations (e.g., file, stderr) without changing the code.
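For example, a minimal setup along these lines would replace the print calls (a sketch; the logger name, format, and placeholder path are illustrative, not from the PR):

```python
import logging
from pathlib import Path

logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("convert_fsdp_to_hf")

chkpt_dir = Path("ckpts")  # placeholder path for illustration
# Instead of: print(f"[error] Expected 'policy/' under {chkpt_dir}")
logger.error("Expected 'policy/' under %s", chkpt_dir)
```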
```python
    Then convert these local tensors to CPU tensors and build a new
    dictionary of keys -> tensors.
    """
    import torch
```
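For context, the DTensor-to-local-tensor step this docstring describes typically looks something like the following. This is a sketch, assuming the shard values are torch.distributed DTensors; the helper name `to_cpu_state_dict` is hypothetical, and `full_tensor()` is a collective that must run on all ranks:

```python
import torch
# Public in torch >= 2.4; earlier releases expose this under torch.distributed._tensor.
from torch.distributed.tensor import DTensor

def to_cpu_state_dict(state_dict: dict) -> dict:
    """Materialize any DTensor shards as full, plain CPU tensors."""
    out = {}
    for key, value in state_dict.items():
        if isinstance(value, DTensor):
            # All-gather the shards into a regular torch.Tensor.
            value = value.full_tensor()
        out[key] = value.detach().cpu()
    return out
```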
Wrote a script for converting FSDP2 model shards into Hugging Face .safetensors.
Filepath:
~/skyrl-train/scripts/convert_fsdp_to_hf.py
Tested on a model trained with run_gsm8k.sh using the fsdp2 backend.
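A quick way to sanity-check the converted file is to load it back with the safetensors library and inspect the keys (the output filename below is a placeholder; substitute the script's actual output path):

```python
from safetensors.torch import load_file

state_dict = load_file("model.safetensors")  # placeholder path
print(len(state_dict), "tensors")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)
```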