Skip to content

Commit 35f0908

Browse files
feat: add remove optim func to checkpoint_util.py
Signed-off-by: yashasvi <[email protected]>
1 parent 092a437 commit 35f0908

File tree

1 file changed

+67
-8
lines changed

1 file changed

+67
-8
lines changed
Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# Standard
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, Iterable, Set
88
import argparse
99
import shutil
1010

@@ -29,7 +29,7 @@ def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
2929
out = {}
3030
for k, v in x.items():
3131
k_lower = k.lower() if isinstance(k, str) else ""
32-
child_in_optim = in_optim or (k_lower in OPTIM_ROOT_KEYS)
32+
child_in_optim = in_optim or any(k_lower.startswith(root) for root in OPTIM_ROOT_KEYS)
3333
out[k] = cast_fp32_to_bf16(v, in_optim=child_in_optim)
3434
return out
3535
if isinstance(x, (list, tuple)):
@@ -46,8 +46,8 @@ def convert_pt_pth(inp: Path, out: Path) -> None:
4646

4747

4848
def is_optim_tensor_name(name: str) -> bool:
49-
parts = (name or "").lower().replace("/", ".").split(".")
50-
return bool(parts) and parts[0] in OPTIM_ROOT_KEYS
49+
first = (name or "").lower().replace("/", ".").split(".")[0]
50+
return any(first.startswith(root) for root in OPTIM_ROOT_KEYS)
5151

5252

5353
def convert_safetensors_file(inp: Path, out: Path) -> None:
@@ -65,6 +65,27 @@ def convert_safetensors_file(inp: Path, out: Path) -> None:
6565
print(f"[safetensors] wrote: {out}")
6666

6767

68+
def slim_copy_dir_skip_only(src: Path, dst: Path, skip_names: Iterable[str]) -> None:
69+
"""
70+
Copy everything from src -> dst EXCEPT files whose names are in skip_names.
71+
Directories are copied entirely unless their name is in skip_names.
72+
"""
73+
dst.mkdir(parents=True, exist_ok=True)
74+
skip_set: Set[str] = set(skip_names)
75+
76+
for item in src.iterdir():
77+
if item.name in skip_set:
78+
continue
79+
target = dst / item.name
80+
if item.is_file():
81+
shutil.copy2(item, target)
82+
elif item.is_dir():
83+
shutil.copytree(item, target, dirs_exist_ok=True)
84+
print(
85+
f"[slim] wrote: {dst} (skipped: {', '.join(skip_set) if skip_set else 'none'})"
86+
)
87+
88+
6889
def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
6990
"""Convert all .safetensors in a directory; copy other files as-is."""
7091
dst.mkdir(parents=True, exist_ok=True)
@@ -82,16 +103,45 @@ def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
82103

83104
def main():
84105
ap = argparse.ArgumentParser(
85-
description="Convert FP32 tensors to BF16 (skip optimizer states)."
106+
description="Convert FP32 tensors to BF16 (skips optimizer states)."
86107
)
87108
ap.add_argument(
88109
"input",
89110
type=Path,
90111
help="Input: .pt/.pth, .safetensors, or HF directory with .safetensors",
91112
)
92113
ap.add_argument("output", type=Path, help="Output file or directory")
114+
115+
ap.add_argument(
116+
"--slim",
117+
action="store_true",
118+
help="For directory inputs: after conversion, copy everything \
119+
except files listed in --skip (default: optimizer.pt).",
120+
)
121+
ap.add_argument(
122+
"--slim-only",
123+
action="store_true",
124+
help="For directory inputs: DO NOT convert; just copy everything \
125+
except files in --skip.",
126+
)
127+
ap.add_argument(
128+
"--skip",
129+
default="optimizer.pt",
130+
help="Comma-separated file names to skip during slimming \
131+
(applies to --slim or --slim-only). Default: optimizer.pt",
132+
)
133+
93134
args = ap.parse_args()
94135

136+
if args.slim and args.slim_only:
137+
raise SystemExit("Choose at most one of: --slim or --slim-only.")
138+
139+
skip_list = (
140+
[s.strip() for s in args.skip.split(",")]
141+
if (args.slim or args.slim_only)
142+
else []
143+
)
144+
95145
p = args.input
96146
if p.is_file():
97147
sfx = p.suffix.lower()
@@ -111,10 +161,19 @@ def main():
111161
else:
112162
raise SystemExit(f"Unsupported file type: {p}")
113163
elif p.is_dir():
114-
if any(x.suffix == ".safetensors" for x in p.iterdir()):
115-
convert_dir_of_safetensors(p, args.output)
116-
else:
164+
if not any(x.suffix == ".safetensors" for x in p.iterdir()):
117165
raise SystemExit("Directory has no .safetensors files.")
166+
if args.slim_only:
167+
slim_copy_dir_skip_only(p, args.output, skip_list)
168+
else:
169+
convert_dir_of_safetensors(p, args.output)
170+
if args.slim:
171+
tmp = args.output.parent / (args.output.name + "_tmp_slim")
172+
if tmp.exists():
173+
shutil.rmtree(tmp)
174+
slim_copy_dir_skip_only(args.output, tmp, skip_list)
175+
shutil.rmtree(args.output)
176+
tmp.rename(args.output)
118177
else:
119178
raise SystemExit(f"Not found: {p}")
120179

0 commit comments

Comments
 (0)