Commit 54762b7

Merge pull request #1101 from cklxx/codex/fix-issue-1094-in-slime-repository
Added FSDP checkpoint handling to convert_torch_dist_to_hf.py
2 parents 34bb0cd + 75db65d commit 54762b7

3 files changed: +195 -0 lines


docs/en/get_started/quick_start.md

Lines changed: 8 additions & 0 deletions
@@ -108,6 +108,14 @@ PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \

Note that Megatron pads the embedding for better performance, so the converted embedding may be incorrect. In that case, manually set `--vocab-size` during conversion.
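
As a minimal sketch (assuming the GLM checkpoint path used in the example below), the unpadded vocabulary size can usually be read from the original Hugging Face config and passed via `--vocab-size`:

```python
# Sketch: look up the unpadded vocabulary size in the original HF config.
# The attribute name may differ for custom model configs.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("/root/GLM-Z1-9B-0414", trust_remote_code=True)
print(config.vocab_size)  # pass this value as --vocab-size
```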

For FSDP checkpoints (without `common.pt`), use the dedicated conversion script. Point `--input-dir` to the checkpoint directory (e.g. `iter_xxx` or `iter_xxx/model`) and provide the original Hugging Face directory:
```bash
python tools/convert_fsdp_to_hf.py \
    --input-dir /path/to/fsdp_ckpt/iter_xxx \
    --output-dir /root/fsdp-converted \
    --origin-hf-dir /root/GLM-Z1-9B-0414
```

## Training Script and Parameter Overview

docs/zh/get_started/quick_start.md

Lines changed: 9 additions & 0 deletions
@@ -108,6 +108,15 @@ PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \

Since Megatron pads the embedding, the shape of the converted embedding weights may not match. In that case, set `--vocab-size` during conversion.

For checkpoints trained and saved with the FSDP backend (directories without `common.pt`), use the dedicated conversion script. Point `--input-dir` to the checkpoint directory (e.g. `iter_xxx` or `iter_xxx/model`) and provide the original Hugging Face model path:

```bash
python tools/convert_fsdp_to_hf.py \
    --input-dir /path/to/fsdp_ckpt/iter_xxx \
    --output-dir /root/fsdp-converted \
    --origin-hf-dir /root/GLM-Z1-9B-0414
```

## Training Script and Parameter Overview

After completing the above preparation, you can run the training script.

tools/convert_fsdp_to_hf.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
import argparse
import os
import pickle
import shutil
import time

import torch
import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig, AutoModelForCausalLM
from typing_extensions import override


class UnpicklerWrapper(pickle.Unpickler):
    """Unpickler that stubs out classes from `megatron`/`glm` modules so checkpoint
    metadata can be unpickled without those packages being importable."""

    @override
    def find_class(self, mod_name, name):
        class DummyClass:
            def __init__(self, *args, **kwargs):
                pass

        if mod_name.startswith("megatron") or mod_name.startswith("glm"):
            return DummyClass
        return super().find_class(mod_name, name)


class WrappedStorageReader(dist_cp.FileSystemReader):
    """FileSystemReader that loads `.metadata` with UnpicklerWrapper and normalizes
    the `storage_meta` and `planner_data` fields."""

    @override
    def read_metadata(self):
        path = self.fs.concat_path(self.path, ".metadata")
        with self.fs.create_stream(path, "rb") as metadata_file:
            metadata = UnpicklerWrapper(metadata_file).load()
        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = dist_cp.StorageMeta()
        metadata.storage_meta.load_id = self.load_id
        if metadata.planner_data is None:
            metadata.planner_data = {}
        return metadata


class EmptyStateDictLoadPlanner(dist_cp.default_planner.DefaultLoadPlanner):
    """Load planner that builds the target state dict from the checkpoint metadata
    itself (skipping optimizer entries), so no live model instance is required."""

    @override
    def set_up_planner(
        self,
        state_dict: dist_cp.metadata.STATE_DICT_TYPE,
        metadata: dist_cp.metadata.Metadata | None = None,
        is_coordinator: bool = False,
    ) -> None:
        for k, v in metadata.state_dict_metadata.items():
            if "optimizer" in k:
                continue
            print(f"find {k} in torch_dist ckpt")
            if isinstance(v, dist_cp.metadata.TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            state_dict[k] = v
        super().set_up_planner(state_dict, metadata, is_coordinator)


def _detect_model_dir(input_dir: str) -> str:
    # Accept either `iter_xxx` (which contains a `model/` subdirectory) or
    # `iter_xxx/model` directly.
    model_dir = os.path.join(input_dir, "model")
    return model_dir if os.path.isdir(model_dir) else input_dir


def _load_fsdp_state_dict(input_dir: str) -> dict[str, torch.Tensor]:
    # Load the sharded torch.distributed checkpoint into a single in-memory
    # state dict without initializing a process group (no_dist=True).
    state_dict: dict[str, torch.Tensor] = {}
    dist_cp.state_dict_loader._load_state_dict(
        state_dict,
        storage_reader=WrappedStorageReader(input_dir),
        planner=EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    return state_dict


def _get_candidate_prefixes(keys: list[str]) -> list[str]:
    predefined = [
        "model_state.model.",
        "model_state.",
        "model.",
        "module.",
        "",
    ]

    detected: set[str] = set()
    for key in keys:
        for prefix in predefined:
            if prefix and key.startswith(prefix):
                detected.add(prefix)

    # Always keep empty string as a fallback option for exact match.
    detected.add("")
    # Preserve predefined order while keeping only detected prefixes.
    return [p for p in predefined if p in detected]


def _strip_best_prefix(keys: list[str], target_keys: set[str]) -> tuple[str, int]:
    # Choose the prefix whose removal maps the most checkpoint keys onto the
    # Hugging Face model's parameter names.
    best_prefix = ""
    best_match = -1

    for prefix in _get_candidate_prefixes(keys):
        mapped_keys = {k.removeprefix(prefix) for k in keys}
        match_count = len(mapped_keys & target_keys)
        if match_count > best_match:
            best_match = match_count
            best_prefix = prefix

    return best_prefix, best_match


def _convert_fsdp_to_hf(
    origin_hf_dir: str,
    input_dir: str,
    output_dir: str,
) -> None:
    print(f"loading FSDP model from {input_dir}")
    t = time.time()
    state_dict = _load_fsdp_state_dict(input_dir)
    print(f"FSDP model loaded in {time.time()-t:.2f} sec.")

    tensor_items = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

    # Instantiate an empty HF model from the original config to obtain the
    # target parameter names.
    config = AutoConfig.from_pretrained(origin_hf_dir, trust_remote_code=True)
    hf_model = AutoModelForCausalLM.from_config(config)
    target_keys = set(hf_model.state_dict().keys())

    best_prefix, best_match = _strip_best_prefix(list(tensor_items.keys()), target_keys)
    total_keys = len(tensor_items)

    print(f"Using prefix '{best_prefix}' for key mapping. Matched {best_match}/{total_keys} parameter keys.")

    model_state = {k.removeprefix(best_prefix): v for k, v in tensor_items.items()}

    if not model_state:
        raise ValueError(
            "No model weights found in checkpoint. "
            "Please pass the checkpoint directory (e.g. iter_xxx or iter_xxx/model)."
        )

    missing, unexpected = hf_model.load_state_dict(model_state, strict=False)
    print(f"Missing keys: {missing}\nUnexpected keys: {unexpected}")

    os.makedirs(output_dir, exist_ok=True)
    hf_model.save_pretrained(output_dir, safe_serialization=True)
    print(f"Model weights saved to {output_dir}")


def copy_assets(origin_hf_dir: str, output_dir: str) -> None:
    # Copy config/tokenizer files from the original HF directory, skipping the
    # original safetensors weights and their index.
    for filename in os.listdir(origin_hf_dir):
        if filename == "model.safetensors.index.json" or filename.endswith(".safetensors"):
            continue
        origin_filename = os.path.join(origin_hf_dir, filename)
        if not os.path.isfile(origin_filename):
            print(f"Skip {filename}, not a file.")
            continue
        src, dst = origin_filename, os.path.join(output_dir, filename)
        print(f"copy from {src} to {dst}")
        shutil.copy(src, dst)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument(
        "--origin-hf-dir",
        type=str,
        required=True,
        help="The original Hugging Face model directory to load config/tokenizer assets.",
    )
    parser.add_argument(
        "-f", "--force", action="store_true", help="Force overwrite the output directory if it exists."
    )
    args = parser.parse_args()

    if os.path.exists(args.output_dir) and not args.force:
        raise ValueError(f"Output directory {args.output_dir} already exists. Use --force to overwrite it.")

    model_dir = _detect_model_dir(args.input_dir)
    _convert_fsdp_to_hf(args.origin_hf_dir, model_dir, args.output_dir)
    copy_assets(args.origin_hf_dir, args.output_dir)
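
After conversion, the output directory contains the weights plus the config/tokenizer assets copied by `copy_assets`, so it can be loaded like a regular Hugging Face checkpoint. A minimal sketch of such a sanity check, assuming the paths from the quick-start example above:

```python
# Sketch: load the converted checkpoint back with transformers and run a short generation.
# Paths follow the quick-start example; adjust dtype/device to your setup.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

output_dir = "/root/fsdp-converted"

tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    output_dir, torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("Hello, world", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```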
