Skip to content

Commit 9c2bf84

Browse files
feat: add remove optim func to checkpoint_util.py
Signed-off-by: yashasvi <[email protected]>
1 parent 092a437 commit 9c2bf84

File tree

1 file changed

+69
-8
lines changed

1 file changed

+69
-8
lines changed
Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# Standard
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, Iterable, Set
88
import argparse
99
import shutil
1010

@@ -29,7 +29,9 @@ def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
2929
out = {}
3030
for k, v in x.items():
3131
k_lower = k.lower() if isinstance(k, str) else ""
32-
child_in_optim = in_optim or (k_lower in OPTIM_ROOT_KEYS)
32+
child_in_optim = in_optim or any(
33+
k_lower.startswith(root) for root in OPTIM_ROOT_KEYS
34+
)
3335
out[k] = cast_fp32_to_bf16(v, in_optim=child_in_optim)
3436
return out
3537
if isinstance(x, (list, tuple)):
@@ -46,8 +48,8 @@ def convert_pt_pth(inp: Path, out: Path) -> None:
4648

4749

4850
def is_optim_tensor_name(name: str) -> bool:
49-
parts = (name or "").lower().replace("/", ".").split(".")
50-
return bool(parts) and parts[0] in OPTIM_ROOT_KEYS
51+
first = (name or "").lower().replace("/", ".").split(".")[0]
52+
return any(first.startswith(root) for root in OPTIM_ROOT_KEYS)
5153

5254

5355
def convert_safetensors_file(inp: Path, out: Path) -> None:
@@ -65,6 +67,27 @@ def convert_safetensors_file(inp: Path, out: Path) -> None:
6567
print(f"[safetensors] wrote: {out}")
6668

6769

70+
def slim_copy_dir_skip_only(src: Path, dst: Path, skip_names: Iterable[str]) -> None:
71+
"""
72+
Copy everything from src -> dst EXCEPT files whose names are in skip_names.
73+
Directories are copied entirely unless their name is in skip_names.
74+
"""
75+
dst.mkdir(parents=True, exist_ok=True)
76+
skip_set: Set[str] = set(skip_names)
77+
78+
for item in src.iterdir():
79+
if item.name in skip_set:
80+
continue
81+
target = dst / item.name
82+
if item.is_file():
83+
shutil.copy2(item, target)
84+
elif item.is_dir():
85+
shutil.copytree(item, target, dirs_exist_ok=True)
86+
print(
87+
f"[slim] wrote: {dst} (skipped: {', '.join(skip_set) if skip_set else 'none'})"
88+
)
89+
90+
6891
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
6992
"""Convert all .safetensors in a directory; copy other files as-is."""
7093
dst.mkdir(parents=True, exist_ok=True)
@@ -82,16 +105,45 @@ def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
82105

83106
def main():
84107
ap = argparse.ArgumentParser(
85-
description="Convert FP32 tensors to BF16 (skip optimizer states)."
108+
description="Convert FP32 tensors to BF16 (skips optimizer states)."
86109
)
87110
ap.add_argument(
88111
"input",
89112
type=Path,
90113
help="Input: .pt/.pth, .safetensors, or HF directory with .safetensors",
91114
)
92115
ap.add_argument("output", type=Path, help="Output file or directory")
116+
117+
ap.add_argument(
118+
"--slim",
119+
action="store_true",
120+
help="For directory inputs: after conversion, copy everything \
121+
except files listed in --skip (default: optimizer.pt).",
122+
)
123+
ap.add_argument(
124+
"--slim-only",
125+
action="store_true",
126+
help="For directory inputs: DO NOT convert; just copy everything \
127+
except files in --skip.",
128+
)
129+
ap.add_argument(
130+
"--skip",
131+
default="optimizer.pt",
132+
help="Comma-separated file names to skip during slimming \
133+
(applies to --slim or --slim-only). Default: optimizer.pt",
134+
)
135+
93136
args = ap.parse_args()
94137

138+
if args.slim and args.slim_only:
139+
raise SystemExit("Choose at most one of: --slim or --slim-only.")
140+
141+
skip_list = (
142+
[s.strip() for s in args.skip.split(",")]
143+
if (args.slim or args.slim_only)
144+
else []
145+
)
146+
95147
p = args.input
96148
if p.is_file():
97149
sfx = p.suffix.lower()
@@ -111,10 +163,19 @@ def main():
111163
else:
112164
raise SystemExit(f"Unsupported file type: {p}")
113165
elif p.is_dir():
114-
if any(x.suffix == ".safetensors" for x in p.iterdir()):
115-
convert_dir_of_safetensors(p, args.output)
116-
else:
166+
if not any(x.suffix == ".safetensors" for x in p.iterdir()):
117167
raise SystemExit("Directory has no .safetensors files.")
168+
if args.slim_only:
169+
slim_copy_dir_skip_only(p, args.output, skip_list)
170+
else:
171+
convert_dir_of_safetensors(p, args.output)
172+
if args.slim:
173+
tmp = args.output.parent / (args.output.name + "_tmp_slim")
174+
if tmp.exists():
175+
shutil.rmtree(tmp)
176+
slim_copy_dir_skip_only(args.output, tmp, skip_list)
177+
shutil.rmtree(args.output)
178+
tmp.rename(args.output)
118179
else:
119180
raise SystemExit(f"Not found: {p}")
120181

0 commit comments

Comments
 (0)