-
Notifications
You must be signed in to change notification settings - Fork 222
Added conversion pipeline of FSDP2 model checkpoint shards to Huggingface .safetensors #618
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,378 @@ | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Helper Script to convert FSDP shards to safetensor model files, compatible with Huggingface API | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| The main purpose is to be able to enable users who choose not to enable HF model saves during training, such as enable the `hf_save_interval` parameter, to | ||||||||||||||||||||||
| also be able to benefit from a way to create a HF safetensors model. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| For FSDP2 model shards, the output directory will be created with the following structure: | ||||||||||||||||||||||
| . | ||||||||||||||||||||||
| ├── added_tokens.json | ||||||||||||||||||||||
| ├── chat_template.jinja (optional: this file is for chat specific tasks) | ||||||||||||||||||||||
| ├── config.json | ||||||||||||||||||||||
| ├── generation_config.json (optional: default decoding parameters) | ||||||||||||||||||||||
| ├── merges.txt | ||||||||||||||||||||||
| ├── model.safetensors | ||||||||||||||||||||||
| ├── special_tokens_map.json | ||||||||||||||||||||||
| ├── tokenizer.json | ||||||||||||||||||||||
| ├── tokenizer_config.json | ||||||||||||||||||||||
| └── vocab.json | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Example usage: | ||||||||||||||||||||||
| uv run --isolated --frozen --extra vllm scripts/convert_fsdp_to_hf.py --ckpt-dir /home/ray/ckpts/gsm8k_0.5B_ckpt/global_step_10 --out-dir /home/ray/hf/glob_step_10 | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||
| import re | ||||||||||||||||||||||
| import sys | ||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||
| import shutil | ||||||||||||||||||||||
| from typing import Dict, List, Optional | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModel | ||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def find_policy_dir(chkpt_dir: Path) -> Path: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Return a Path object to the policy directory | ||||||||||||||||||||||
| - Path to the policy directory | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| pol = chkpt_dir / "policy" | ||||||||||||||||||||||
| if not pol.exists(): | ||||||||||||||||||||||
| print(f"[error] Expected 'policy/' under {chkpt_dir}") | ||||||||||||||||||||||
| raise FileNotFoundError(f"Expected 'policy/' under {chkpt_dir}") | ||||||||||||||||||||||
| return pol | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get_model_shards(policy_dir: Path) -> List[Path]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Return a list of model Path objects | ||||||||||||||||||||||
| - List[Path] of the model shards (the model*.pt files) | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| shards_paths = sorted(policy_dir.glob("model_world_size_*_rank_*.pt")) | ||||||||||||||||||||||
| if not shards_paths: | ||||||||||||||||||||||
| shards_paths = sorted(policy_dir.glob("model*.pt")) | ||||||||||||||||||||||
| if not shards_paths: | ||||||||||||||||||||||
| print(f"[error] No model shards found under {policy_dir}") | ||||||||||||||||||||||
| raise FileNotFoundError(f"No model shards found under {policy_dir}") | ||||||||||||||||||||||
| return shards_paths | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Not used at the moment | ||||||||||||||||||||||
| def normalize_key(k: str) -> str: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Return a normalized key to ensure consistency across checkpointing frameworks | ||||||||||||||||||||||
| Example - Attention layer training: | ||||||||||||||||||||||
| "module.encoder.layer.0.attention.self.query.weight" | ||||||||||||||||||||||
| "model.module.encoder.layer.0.attention.self.query.weight" | ||||||||||||||||||||||
| "encoder.layer.0.attention.self.query.weight" | ||||||||||||||||||||||
| These 3 should refer to the same thing, so they should be normalized. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Function takes string and removes all possible prefixes. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| k = re.sub(r"^(module|model)\.", "", k) | ||||||||||||||||||||||
| k = k.replace("_fsdp_wrapped_module.", "") | ||||||||||||||||||||||
| return k | ||||||||||||||||||||||
|
Comment on lines
+61
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def load_single_shard(path: Path) -> Dict[str, torch.Tensor]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Load a single model shard and return a dictionary of tensors | ||||||||||||||||||||||
| - Dict[str, torch.Tensor] | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| obj = torch.load(path, map_location="cpu", weights_only=False) | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||||||
| for key in ("state_dict", "model", "module"): | ||||||||||||||||||||||
| if isinstance(obj, dict) and key in obj and isinstance(obj[key], dict): | ||||||||||||||||||||||
| obj = obj[key] | ||||||||||||||||||||||
| if not isinstance(obj, dict): | ||||||||||||||||||||||
| raise ValueError(f"Unexpected checkpoint format at {path} (type={type(obj)})") | ||||||||||||||||||||||
| return {k: v for k, v in obj.items() if isinstance(v, torch.Tensor)} | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def is_vocab_key(k: str) -> bool: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Check for the keys in the state_dict that we want to merge shards for, in a MultiGPU setup | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| These keys are the keys in the state_dict that should be merged into a single tensor (from all the shards) | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| suffixes = ( | ||||||||||||||||||||||
| "embed_tokens.weight", | ||||||||||||||||||||||
| "lm_head.weight", | ||||||||||||||||||||||
| "word_embeddings.weight", | ||||||||||||||||||||||
| "wte.weight", | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return any(k.endswith(sfx) for sfx in suffixes) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def merge_two_shards( | ||||||||||||||||||||||
| existing: torch.Tensor, new_shard: torch.Tensor, key: str, merge_type: Optional[str] = None | ||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Merge two tensor shards into a single tensor, containing both the existing and new shards | ||||||||||||||||||||||
| Merge tensors with key in the vocabulary as the default case, always. | ||||||||||||||||||||||
| If key (tensor name) is not in the vocabulary then we merge based on the `merge_type` parameter | ||||||||||||||||||||||
| - (str=default) Default: On the first dimension of the tensor | ||||||||||||||||||||||
| - (str=vocab_parallel) Vocab-parallel (embeddings/lm_head): cat_dim = 0 (batch / words) | ||||||||||||||||||||||
| - (str=col_parallel) Column-parallel linear (weights split output features): cat_dim = 1 (output features) | ||||||||||||||||||||||
| - (str=row_parallel) Row-parallel linear ( weights split input features): cat_dim = 2 (input features) | ||||||||||||||||||||||
| - (str=other) Replicated Parameters (LayerNorm, RMSNorm, rotary, etc): sum existing and new_shard | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| If the merge_type is equal to None, then we use heuristic fallback. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Heuristic Fallback: | ||||||||||||||||||||||
| * If only dim0 differs -> cat dim=0 | ||||||||||||||||||||||
| * If only dim1 differs -> cat dim=1 | ||||||||||||||||||||||
| * If only dim2 differs -> cat dim=2 | ||||||||||||||||||||||
| * If shapes equal -> add (sum) (useful for row-parallel biases) | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if is_vocab_key(key) or merge_type == "default": | ||||||||||||||||||||||
| return torch.cat([existing, new_shard], dim=0) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if merge_type == "vocab_parallel": | ||||||||||||||||||||||
| return torch.cat([existing, new_shard], dim=0) | ||||||||||||||||||||||
| elif merge_type == "col_parallel": | ||||||||||||||||||||||
| if existing.ndim >= 2 and new_shard.ndim >= 2: | ||||||||||||||||||||||
| return torch.cat([existing, new_shard], dim=1) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"existing.ndim={existing.ndim}, new_shard.ndim={new_shard.ndim}, cannot do col_parallel merging because at least 2 dimensions of both tensors are required" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| elif merge_type == "row_parallel": | ||||||||||||||||||||||
| if existing.ndim >= 3 and new_shard.ndim >= 3: | ||||||||||||||||||||||
| return torch.cat([existing, new_shard], dim=2) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"existing.ndim={existing.ndim}, new_shard.ndim={new_shard.ndim}, cannot do row_parallel merging because at least 3 dimensions of both tensors are required" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return torch.cat([existing, new_shard], dim=2) | ||||||||||||||||||||||
| elif merge_type == "other": | ||||||||||||||||||||||
| return existing | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| ## merge_type = None or unknown, then we simply merge by heuristic | ||||||||||||||||||||||
| if existing.ndim >= 2 and existing.shape[0] != new_shard.shape[0] and existing.shape[1] == new_shard.shape[1]: | ||||||||||||||||||||||
| # Likely word-parallel linear weight (PyTorch Linear is [out, in]) | ||||||||||||||||||||||
| return torch.cat((existing, new_shard), dim=0) | ||||||||||||||||||||||
| if existing.ndim >= 2 and existing.shape[0] == new_shard.shape[0] and existing.shape[1] != new_shard.shape[1]: | ||||||||||||||||||||||
| # Likely col-parallel linear weight | ||||||||||||||||||||||
| return torch.cat((existing, new_shard), dim=1) | ||||||||||||||||||||||
| if ( | ||||||||||||||||||||||
| existing.ndim >= 3 | ||||||||||||||||||||||
| and existing.shape[0] == new_shard.shape[0] | ||||||||||||||||||||||
| and existing.shape[1] == new_shard.shape[1] | ||||||||||||||||||||||
| and existing.shape[2] != new_shard.shape[2] | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| # Likely row-parallel linear weight | ||||||||||||||||||||||
| return torch.cat((existing, new_shard), dim=2) | ||||||||||||||||||||||
| if existing.shape == new_shard.shape: | ||||||||||||||||||||||
| # Could be row-parallel bias or replicated tensors. | ||||||||||||||||||||||
| # Try SUM | ||||||||||||||||||||||
| return existing + new_shard | ||||||||||||||||||||||
|
Comment on lines
+166
to
+169
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The heuristic for merging tensors of the same shape is to sum them. This is correct for some cases like row-parallel biases, but incorrect for replicated parameters (e.g., LayerNorm weights), which should be identical across shards. Summing replicated parameters will result in incorrect weights (effectively scaling them by the world size), leading to silent model corruption. A safer heuristic would be to distinguish between these cases, for example by checking for 'bias' in the tensor key.
Suggested change
|
||||||||||||||||||||||
| raise ValueError(f"Don't know how to merge key '{key}' with shapes {existing.shape} and {new_shard.shape}") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def merge_shards(shards_paths: List[Path]) -> Dict[str, torch.Tensor]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Merge all model shards into a single dictionary of string-based keys to their corresponding tensors | ||||||||||||||||||||||
| - Dict[str, torch.Tensor] | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| merged: Dict[str, torch.Tensor] = {} | ||||||||||||||||||||||
| for shard in shards_paths: | ||||||||||||||||||||||
| sd = load_single_shard(shard) | ||||||||||||||||||||||
| for k, v in sd.items(): | ||||||||||||||||||||||
| nk = k | ||||||||||||||||||||||
| # if nk in merged: | ||||||||||||||||||||||
| # if merged[nk].shape != v.shape or merged[nk].dtype != v.dtype: | ||||||||||||||||||||||
| # print( | ||||||||||||||||||||||
| # f"[error] Key collision with mismatch for '{nk}' between shards " | ||||||||||||||||||||||
| # f"(existing {merged[nk].shape}/{merged[nk].dtype} vs {v.shape}/{v.dtype})" | ||||||||||||||||||||||
| # ) | ||||||||||||||||||||||
| # raise ValueError( | ||||||||||||||||||||||
| # f"Key collision with mismatch for '{nk}' between shards " | ||||||||||||||||||||||
| # f"(existing {merged[nk].shape}/{merged[nk].dtype} vs {v.shape}/{v.dtype})" | ||||||||||||||||||||||
| # ) | ||||||||||||||||||||||
| # else: | ||||||||||||||||||||||
| # # Merging tensors using merge_two_shards with heuristic fallback | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| from torch.distributed.tensor import DTensor | ||||||||||||||||||||||
| except (ImportError, AttributeError): | ||||||||||||||||||||||
| DTensor = None | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| from torch.distributed.tensor import ShardedTensor | ||||||||||||||||||||||
| except (ImportError, AttributeError): | ||||||||||||||||||||||
| ShardedTensor = None | ||||||||||||||||||||||
| if DTensor is not None and isinstance(v, DTensor): | ||||||||||||||||||||||
| v = v.to_local() | ||||||||||||||||||||||
| if ShardedTensor is not None and isinstance(v, ShardedTensor): | ||||||||||||||||||||||
| v = v.local_tensor() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if nk in merged: | ||||||||||||||||||||||
| print( | ||||||||||||||||||||||
| f"Tensors to be merged: (original) with shape {merged[nk].shape} and (newly added) with shape {v.shape}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| merged[nk] = merge_two_shards(merged[nk], v.detach().cpu().contiguous(), key=nk, merge_type="default") | ||||||||||||||||||||||
erictang000 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| print(f"Current tensor shape: {v.shape}") | ||||||||||||||||||||||
| merged[nk] = v.detach().cpu().contiguous() | ||||||||||||||||||||||
| if not merged: | ||||||||||||||||||||||
| print("[error] No tensors found in shards") | ||||||||||||||||||||||
| raise RuntimeError("No tensors found in shards") | ||||||||||||||||||||||
| return merged | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def copy_hf_artifacts(policy_dir: Path, out_dir: Path) -> None: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Copy huggingface artifacts from the policy directory to the output directory | ||||||||||||||||||||||
| - A utility function that copies huggingface artifacts from the policy directory to the output directory | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| hf_src = policy_dir / "huggingface" | ||||||||||||||||||||||
| out_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||
| if hf_src.exists(): | ||||||||||||||||||||||
| for p in hf_src.iterdir(): | ||||||||||||||||||||||
| dst = out_dir / p.name | ||||||||||||||||||||||
| if p.is_file(): | ||||||||||||||||||||||
| shutil.copy2(p, dst) | ||||||||||||||||||||||
| elif p.is_dir(): | ||||||||||||||||||||||
| if dst.exists(): | ||||||||||||||||||||||
| shutil.rmtree(dst) | ||||||||||||||||||||||
| shutil.copytree(p, dst) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| print("[warn] policy/huggingface not found; you must supply a proper config/tokenizer.", file=sys.stderr) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _materialize_for_safetensors(state_dict) -> Dict[str, torch.Tensor]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Materialize the state dict for safetensors | ||||||||||||||||||||||
| - A utility function that materializes the state dict for safetensors | ||||||||||||||||||||||
| Essentially converts all torch tensors to local tensors so they can actually be saved. | ||||||||||||||||||||||
| 1) DTensor to local tensor | ||||||||||||||||||||||
| 2) ShardedTensor to local tensor | ||||||||||||||||||||||
| We do not save meta tensors because they have no data and are not materializable. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Then after that, convert these local tensors to cpu tensors, and create a new dictionary of keys -> Tensors. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| new_sd = {} | ||||||||||||||||||||||
| for k, v in state_dict.items(): | ||||||||||||||||||||||
| if not isinstance(v, torch.Tensor): | ||||||||||||||||||||||
| new_sd[k] = v | ||||||||||||||||||||||
| continue | ||||||||||||||||||||||
| t = v | ||||||||||||||||||||||
| if type(t).__name__ == "DTensor" and hasattr(t, "to_local"): | ||||||||||||||||||||||
| t = t.to_local() | ||||||||||||||||||||||
| if type(t).__name__ == "ShardedTensor" and hasattr(t, "local_tensor"): | ||||||||||||||||||||||
| t = t.local_tensor() | ||||||||||||||||||||||
| if getattr(t, "is_meta", False): | ||||||||||||||||||||||
| raise RuntimeError(f"Tensor {k} is on meta device; load the real weights before saving.") | ||||||||||||||||||||||
| if type(t).__name__ == "FakeTensor": | ||||||||||||||||||||||
| raise RuntimeError(f"Tensor {k} is a FakeTensor; disable fake tensor mode for save") | ||||||||||||||||||||||
| if t.device.type != "cpu": | ||||||||||||||||||||||
| t = t.to("cpu", non_blocking=False) | ||||||||||||||||||||||
| t = t.detach() | ||||||||||||||||||||||
| if not t.is_contiguous(): | ||||||||||||||||||||||
| t = t.contiguous() | ||||||||||||||||||||||
| new_sd[k] = t | ||||||||||||||||||||||
| return new_sd | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def guess_hf_class(cfg: AutoConfig): | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Tries to find a reasonable HF class from config | ||||||||||||||||||||||
| Falls back to the AutoModel architecture if an LM head can't be detected | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if getattr(cfg, "is_encoder_decoder", False): | ||||||||||||||||||||||
| return AutoModelForSeq2SeqLM | ||||||||||||||||||||||
| archs = getattr(cfg, "architectures", []) or [] | ||||||||||||||||||||||
| if any(a.endswith("ForCausalLM") for a in archs): | ||||||||||||||||||||||
| return AutoModelForCausalLM | ||||||||||||||||||||||
| decoders = {"gpt2", "gpt_bigcode", "llama", "mistral", "qwen", "qwen2", "internlm", "mpt", "phi", "falcon"} | ||||||||||||||||||||||
| if getattr(cfg, "model_type", "") in decoders: | ||||||||||||||||||||||
| return AutoModelForCausalLM | ||||||||||||||||||||||
| return AutoModel | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def validate_load(out_dir: Path): | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Optional: sanity-load with HF to ensure the saved safetensors is consumable | ||||||||||||||||||||||
| Loads on the CPU to avoid device / dtype quirk (this may be a problem for loading on GPU which could cause data loading issues) | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| cfg = AutoConfig.from_pretrained(out_dir, local_files_only=True, trust_remote_code=True) | ||||||||||||||||||||||
| HFClass = guess_hf_class(cfg) | ||||||||||||||||||||||
| _ = HFClass.from_pretrained( | ||||||||||||||||||||||
| out_dir, local_files_only=True, device_map=None, dtype="auto", trust_remote_code=True | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| print("[validate] HF Load OK") | ||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||
| print(f"[validate][error] HF Load failed: {e} ", e) | ||||||||||||||||||||||
| raise RuntimeError("HF Load failed") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _untie_shared_tensors(sd) -> Dict[str, torch.Tensor]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Untie shared tensors | ||||||||||||||||||||||
| - A utility function that unties shared tensors | ||||||||||||||||||||||
| Some tensors may be shared by different keys in that the tensors they point to have the same data pointer. | ||||||||||||||||||||||
| This function takes a state dict and returns a new state dict where the shared tensors have been untied. | ||||||||||||||||||||||
| This is done by creating a new tensor (clone it) for each shared tensor. | ||||||||||||||||||||||
| This allows each key to refer to a UNIQUE tensor | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| seen = {} | ||||||||||||||||||||||
| for k, v in list(sd.items()): | ||||||||||||||||||||||
| if not isinstance(v, torch.Tensor): | ||||||||||||||||||||||
| continue | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| ptr = v.storage().data_ptr() | ||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||
| continue | ||||||||||||||||||||||
| if ptr in seen: | ||||||||||||||||||||||
| sd[k] = v.clone() | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| seen[ptr] = (k, v) | ||||||||||||||||||||||
| return sd | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def main(): | ||||||||||||||||||||||
| ap = argparse.ArgumentParser(description="Convert FSDP checkpoint shards to a HuggingFace safetensors model.") | ||||||||||||||||||||||
| ap.add_argument( | ||||||||||||||||||||||
| "--ckpt-dir", type=str, required=True, help="Path to the checkpoint directory, containing trainer_state.pt" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| ap.add_argument("--out-dir", type=str, required=True, help="Output for HF model folder") | ||||||||||||||||||||||
| ap.add_argument( | ||||||||||||||||||||||
| "--validate-load", action="store_true", help="Try loading with the Transformers Module after saving" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| args = ap.parse_args() | ||||||||||||||||||||||
| ckpt_dir = Path(args.ckpt_dir).resolve() | ||||||||||||||||||||||
| output_dir = Path(args.out_dir).resolve() | ||||||||||||||||||||||
| policy_dir = find_policy_dir(ckpt_dir) | ||||||||||||||||||||||
| shards_paths = get_model_shards(policy_dir) | ||||||||||||||||||||||
| print(f"[info] Found {len(shards_paths)} model shard(s). ") | ||||||||||||||||||||||
| for s in shards_paths: | ||||||||||||||||||||||
| print(f"[info] - {s}") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| print("[info] Merging shards...") | ||||||||||||||||||||||
| state_dict = merge_shards(shards_paths) | ||||||||||||||||||||||
| print(f"[info] Merged {len(state_dict)} tensors.") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| copy_hf_artifacts(policy_dir, output_dir) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| clean_sd = _materialize_for_safetensors(state_dict) | ||||||||||||||||||||||
| clean_sd = _untie_shared_tensors(clean_sd) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # save_file(clean_sd, str(weights_path)) | ||||||||||||||||||||||
| # print(f"[success] Saved weights to {weights_path}") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| cfg = AutoConfig.from_pretrained(output_dir, local_files_only=True, trust_remote_code=True) | ||||||||||||||||||||||
| HFClass = guess_hf_class(cfg) | ||||||||||||||||||||||
| hf_model = HFClass.from_config(cfg) | ||||||||||||||||||||||
| hf_model.save_pretrained( | ||||||||||||||||||||||
| save_directory=output_dir, | ||||||||||||||||||||||
| state_dict=clean_sd, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if args.validate_load: | ||||||||||||||||||||||
| validate_load(output_dir) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||
| main() | ||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The script uses
printfor status and error messages. Using the standardloggingmodule is a better practice for scripts like this. It provides more flexibility, such as controlling log levels (INFO, WARN, ERROR), formatting, and directing output to different destinations (e.g., file, stderr) without changing the code.