Skip to content

Commit 230f6b3

Browse files
Merge branch 'main' into chat_template_path
2 parents 4f9e9de + d9ee35f commit 230f6b3

File tree

1 file changed

+310
-0
lines changed

1 file changed

+310
-0
lines changed

scripts/checkpoint_utils.py

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
#!/usr/bin/env python3
2+
# Checkpoint utilities (unified --inplace):
3+
# - Default: copy INPUT -> OUTPUT unchanged
4+
# - --convert-model-to-bf16: convert model FP32 -> BF16 (optimizer tensors remain FP32)
5+
# - --no-optimizer: when writing outputs, drop optimizer files/dirs (defaults + --drop-files)
6+
# - --drop-files: comma-separated extra file/dir (used with --no-optimizer, and with --inplace)
7+
# - --inplace: perform conversion and/or dropping directly in INPUT (destructive)
8+
9+
10+
# Standard
11+
from pathlib import Path
12+
from typing import Any, Iterable, Set
13+
import argparse
14+
import os
15+
import shutil
16+
17+
# Third Party
18+
import torch
19+
20+
try:
21+
# Third Party
22+
from safetensors.torch import safe_open, save_file
23+
24+
HAS_SAFETENSORS = True
25+
except ImportError:
26+
HAS_SAFETENSORS = False
27+
28+
OPTIM_ROOT_KEYS = {"optimizer", "optim", "opt_state"}
29+
30+
DEFAULT_OPTIM_DROPS = {"optimizer.pt", "optimizer", "optimizer_0", "optimizer_1"}
31+
32+
33+
def _atomic_replace(tmp: Path, dst: Path) -> None:
34+
dst.parent.mkdir(parents=True, exist_ok=True)
35+
os.replace(str(tmp), str(dst)) # atomic on POSIX
36+
37+
38+
def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
39+
"""Recursively cast float32 tensors to bfloat16, skipping optimizer subtrees."""
40+
if isinstance(x, torch.Tensor):
41+
return x if in_optim or x.dtype != torch.float32 else x.to(torch.bfloat16)
42+
if isinstance(x, dict):
43+
out = {}
44+
for k, v in x.items():
45+
k_lower = k.lower() if isinstance(k, str) else ""
46+
child_in_optim = in_optim or any(
47+
k_lower.startswith(root) for root in OPTIM_ROOT_KEYS
48+
)
49+
out[k] = cast_fp32_to_bf16(v, in_optim=child_in_optim)
50+
return out
51+
if isinstance(x, (list, tuple)):
52+
return type(x)(cast_fp32_to_bf16(v, in_optim=in_optim) for v in x)
53+
return x
54+
55+
56+
def is_optim_tensor_name(name: str) -> bool:
57+
first = (name or "").lower().replace("/", ".").split(".")[0]
58+
return any(first.startswith(root) for root in OPTIM_ROOT_KEYS)
59+
60+
61+
def convert_pt_pth(inp: Path, out: Path) -> None:
62+
data = torch.load(inp, map_location="cpu")
63+
data = cast_fp32_to_bf16(data)
64+
out.parent.mkdir(parents=True, exist_ok=True)
65+
torch.save(data, out)
66+
print(f"[pt/pth] wrote: {out}")
67+
68+
69+
def convert_pt_pth_inplace(inp: Path) -> None:
70+
tmp = inp.with_suffix(inp.suffix + ".tmp")
71+
convert_pt_pth(inp, tmp)
72+
_atomic_replace(tmp, inp)
73+
print(f"[pt/pth][inplace] updated: {inp}")
74+
75+
76+
def convert_safetensors_file(inp: Path, out: Path) -> None:
77+
if not HAS_SAFETENSORS:
78+
raise RuntimeError("safetensors not installed. pip install safetensors")
79+
tensors = {}
80+
with safe_open(str(inp), framework="pt", device="cpu") as f:
81+
for key in f.keys():
82+
t = f.get_tensor(key)
83+
if t.dtype == torch.float32 and not is_optim_tensor_name(key):
84+
t = t.to(torch.bfloat16)
85+
tensors[key] = t
86+
out.parent.mkdir(parents=True, exist_ok=True)
87+
save_file(tensors, str(out), metadata={"converted_to": "bfloat16"})
88+
print(f"[safetensors] wrote: {out}")
89+
90+
91+
def convert_safetensors_file_inplace(inp: Path) -> None:
92+
tmp = inp.with_suffix(inp.suffix + ".tmp")
93+
convert_safetensors_file(inp, tmp)
94+
_atomic_replace(tmp, inp)
95+
print(f"[safetensors][inplace] updated: {inp}")
96+
97+
98+
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
99+
"""Convert all .safetensors in a directory; copy other files as-is."""
100+
dst.mkdir(parents=True, exist_ok=True)
101+
for item in src.iterdir():
102+
if item.suffix == ".safetensors":
103+
convert_safetensors_file(item, dst / item.name)
104+
else:
105+
target = dst / item.name
106+
if item.is_file():
107+
shutil.copy2(item, target)
108+
elif item.is_dir():
109+
shutil.copytree(item, target, dirs_exist_ok=True)
110+
print(f"[dir] wrote: {dst}")
111+
112+
113+
def convert_dir_of_safetensors_inplace(src: Path) -> None:
114+
"""Convert all .safetensors files in-place within `src`."""
115+
count = 0
116+
for item in src.iterdir():
117+
if item.suffix == ".safetensors":
118+
convert_safetensors_file_inplace(item)
119+
count += 1
120+
if count == 0:
121+
raise SystemExit("Directory has no .safetensors files.")
122+
print(f"[dir][inplace] converted {count} shard(s) in: {src}")
123+
124+
125+
def _name_matches(name: str, patterns: Set[str]) -> bool:
126+
"""Exact-name match (simple and predictable)."""
127+
return name in patterns
128+
129+
130+
def copy_dir_drop(src: Path, dst: Path, drop_names: Iterable[str]) -> None:
131+
"""Copy directory but drop certain files/dirs by exact name."""
132+
dst.mkdir(parents=True, exist_ok=True)
133+
drop_set: Set[str] = set(drop_names)
134+
for item in src.iterdir():
135+
if _name_matches(item.name, drop_set):
136+
continue
137+
target = dst / item.name
138+
if item.is_file():
139+
shutil.copy2(item, target)
140+
elif item.is_dir():
141+
shutil.copytree(item, target, dirs_exist_ok=True)
142+
print(
143+
f"[copy-drop] wrote: {dst} (dropped: {', '.join(sorted(drop_set)) if drop_set else 'none'})"
144+
)
145+
146+
147+
def prune_dir_inplace(src: Path, drop_names: Iterable[str]) -> None:
148+
"""Delete top-level files/dirs in `src` whose names match `drop_names`. Destructive."""
149+
drop_set: Set[str] = set(drop_names)
150+
removed = []
151+
for item in src.iterdir():
152+
if _name_matches(item.name, drop_set):
153+
if item.is_file():
154+
item.unlink()
155+
elif item.is_dir():
156+
shutil.rmtree(item)
157+
removed.append(item.name)
158+
print(
159+
f"[inplace-drop] removed: {', '.join(sorted(removed)) if removed else 'nothing'}"
160+
)
161+
162+
163+
def copy_any(src: Path, dst: Path) -> None:
164+
"""Pure copy (no dtype changes, no dropping)."""
165+
if src.is_file():
166+
dst.parent.mkdir(parents=True, exist_ok=True)
167+
shutil.copy2(src, dst if dst.suffix else dst / src.name)
168+
elif src.is_dir():
169+
dst.mkdir(parents=True, exist_ok=True)
170+
for item in src.iterdir():
171+
target = dst / item.name
172+
if item.is_file():
173+
shutil.copy2(item, target)
174+
elif item.is_dir():
175+
shutil.copytree(item, target, dirs_exist_ok=True)
176+
else:
177+
raise SystemExit(f"Not found: {src}")
178+
print(f"[copy] wrote: {dst}")
179+
180+
181+
def main():
182+
ap = argparse.ArgumentParser(
183+
description="Checkpoint utilities: copy by default; \
184+
optionally convert FP32->BF16 and/or drop optimizer files. "
185+
"Use --inplace to modify INPUT directly."
186+
)
187+
ap.add_argument("input", type=Path, help="Input file or directory")
188+
ap.add_argument("output", type=Path, help="Output file or directory")
189+
190+
ap.add_argument(
191+
"--convert-model-to-bf16",
192+
action="store_true",
193+
help="Convert FP32 -> BF16 for model tensors; optimizer tensors remain FP32.",
194+
)
195+
ap.add_argument(
196+
"--no-optimizer",
197+
action="store_true",
198+
help="When writing outputs, drop optimizer files/dirs (defaults + --drop-files).",
199+
)
200+
ap.add_argument(
201+
"--drop-files",
202+
default="",
203+
help="Comma-separated extra file/dir names to drop \
204+
(works with --no-optimizer and/or --inplace).",
205+
)
206+
ap.add_argument(
207+
"--inplace",
208+
action="store_true",
209+
help="Perform operations directly on INPUT (destructive). For files: overwrite in place; "
210+
"for directories: convert shards in-place and/or delete dropped names.",
211+
)
212+
213+
args = ap.parse_args()
214+
215+
p = args.input
216+
217+
user_drops = {s.strip() for s in args.drop_files.split(",") if s.strip()}
218+
if args.no_optimizer:
219+
drop_set = DEFAULT_OPTIM_DROPS | user_drops
220+
else:
221+
drop_set = user_drops
222+
223+
if args.inplace:
224+
if not p.exists():
225+
raise SystemExit(f"Not found: {p}")
226+
227+
if args.convert_model_to_bf16:
228+
if p.is_file():
229+
sfx = p.suffix.lower()
230+
if sfx in {".pt", ".pth"}:
231+
convert_pt_pth_inplace(p)
232+
elif sfx == ".safetensors":
233+
convert_safetensors_file_inplace(p)
234+
else:
235+
raise SystemExit(
236+
f"Unsupported file type for inplace conversion: {p}"
237+
)
238+
elif p.is_dir():
239+
convert_dir_of_safetensors_inplace(p)
240+
else:
241+
raise SystemExit(f"Not found: {p}")
242+
243+
if drop_set:
244+
if not p.is_dir():
245+
print(
246+
"[inplace] --drop-files applies to directories; skipping for file input."
247+
)
248+
else:
249+
prune_dir_inplace(p, drop_set)
250+
251+
print("Done.")
252+
return
253+
254+
if not args.convert_model_to_bf16 and not args.no_optimizer and not drop_set:
255+
copy_any(p, args.output)
256+
print("Done.")
257+
return
258+
259+
if p.is_file():
260+
sfx = p.suffix.lower()
261+
if args.convert_model_to_bf16:
262+
if sfx in {".pt", ".pth"}:
263+
convert_pt_pth(p, args.output)
264+
elif sfx == ".safetensors":
265+
out = (
266+
args.output
267+
if args.output.suffix == ".safetensors"
268+
else (args.output / p.name)
269+
)
270+
convert_safetensors_file(p, out)
271+
else:
272+
raise SystemExit(f"Unsupported file type: {p}")
273+
else:
274+
copy_any(p, args.output)
275+
print("Done.")
276+
return
277+
278+
if p.is_dir():
279+
if args.convert_model_to_bf16:
280+
if not any(x.suffix == ".safetensors" for x in p.iterdir()):
281+
raise SystemExit("Directory has no .safetensors files.")
282+
convert_dir_of_safetensors(p, args.output)
283+
if args.no_optimizer or drop_set:
284+
tmp = args.output.parent / (args.output.name + "_tmp_drop")
285+
if tmp.exists():
286+
shutil.rmtree(tmp)
287+
copy_dir_drop(
288+
args.output,
289+
tmp,
290+
DEFAULT_OPTIM_DROPS | drop_set if args.no_optimizer else drop_set,
291+
)
292+
shutil.rmtree(args.output)
293+
tmp.rename(args.output)
294+
else:
295+
if args.no_optimizer or drop_set:
296+
copy_dir_drop(
297+
p,
298+
args.output,
299+
DEFAULT_OPTIM_DROPS | drop_set if args.no_optimizer else drop_set,
300+
)
301+
else:
302+
copy_any(p, args.output)
303+
print("Done.")
304+
return
305+
306+
raise SystemExit(f"Not found: {p}")
307+
308+
309+
if __name__ == "__main__":
310+
main()

0 commit comments

Comments
 (0)