44
55# Standard
66from pathlib import Path
7- from typing import Any
7+ from typing import Any , Iterable , Set
88import argparse
99import shutil
1010
@@ -29,7 +29,7 @@ def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
2929 out = {}
3030 for k , v in x .items ():
3131 k_lower = k .lower () if isinstance (k , str ) else ""
32- child_in_optim = in_optim or (k_lower in OPTIM_ROOT_KEYS )
32+ child_in_optim = in_optim or any (k_lower . startswith ( root ) for root in OPTIM_ROOT_KEYS )
3333 out [k ] = cast_fp32_to_bf16 (v , in_optim = child_in_optim )
3434 return out
3535 if isinstance (x , (list , tuple )):
@@ -46,8 +46,8 @@ def convert_pt_pth(inp: Path, out: Path) -> None:
4646
4747
def is_optim_tensor_name(name: str) -> bool:
    """Tell whether a tensor name belongs under an optimizer root key.

    The name is lower-cased and "/" separators are treated like ".".
    Only the first dotted component is inspected: the result is True when
    that component starts with any entry of the module-level
    ``OPTIM_ROOT_KEYS``. A falsy ``name`` (``None`` or ``""``) is handled
    as the empty string.
    """
    normalized = (name or "").lower().replace("/", ".")
    # Everything before the first "." (the whole string when there is none).
    head, _, _ = normalized.partition(".")
    # str.startswith accepts a tuple of prefixes and tests them in one call.
    return head.startswith(tuple(OPTIM_ROOT_KEYS))
5151
5252
5353def convert_safetensors_file (inp : Path , out : Path ) -> None :
@@ -65,6 +65,27 @@ def convert_safetensors_file(inp: Path, out: Path) -> None:
6565 print (f"[safetensors] wrote: { out } " )
6666
6767
def slim_copy_dir_skip_only(src: Path, dst: Path, skip_names: Iterable[str]) -> None:
    """
    Copy the top-level entries of ``src`` into ``dst``, leaving out skipped names.

    Files are copied with ``shutil.copy2``; directories are copied recursively
    with ``shutil.copytree``. The skip filter is applied to the names of the
    direct children of ``src`` only, so a file with a skipped name that sits
    inside a copied sub-directory is still copied.

    Args:
        src: Directory whose entries are copied.
        dst: Destination directory; created (with parents) if missing.
        skip_names: Entry names to leave out. A single ``str`` is treated as
            one name rather than as an iterable of characters.
    """
    dst.mkdir(parents=True, exist_ok=True)

    # A bare string would otherwise be exploded into a set of characters.
    if isinstance(skip_names, str):
        skip_names = [skip_names]
    skip_set: Set[str] = set(skip_names)

    for item in src.iterdir():
        if item.name in skip_set:
            continue
        target = dst / item.name
        if item.is_file():
            shutil.copy2(item, target)
        elif item.is_dir():
            # dirs_exist_ok lets a re-run merge into an existing destination.
            shutil.copytree(item, target, dirs_exist_ok=True)
        # Entries that are neither a file nor a directory (e.g. broken
        # symlinks) are left out.

    # Sort so the log line is stable; set iteration order varies per run.
    skipped = ", ".join(sorted(skip_set)) if skip_set else "none"
    print(f"[slim] wrote: {dst} (skipped: {skipped})")
87+
88+
6889def convert_dir_of_safetensors (src : Path , dst : Path ) -> None :
6990 """Convert all .safetensors in a directory; copy other files as-is."""
7091 dst .mkdir (parents = True , exist_ok = True )
@@ -82,16 +103,45 @@ def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
82103
83104def main ():
84105 ap = argparse .ArgumentParser (
85- description = "Convert FP32 tensors to BF16 (skip optimizer states)."
106+ description = "Convert FP32 tensors to BF16 (skips optimizer states)."
86107 )
87108 ap .add_argument (
88109 "input" ,
89110 type = Path ,
90111 help = "Input: .pt/.pth, .safetensors, or HF directory with .safetensors" ,
91112 )
92113 ap .add_argument ("output" , type = Path , help = "Output file or directory" )
114+
115+ ap .add_argument (
116+ "--slim" ,
117+ action = "store_true" ,
118+ help = "For directory inputs: after conversion, copy everything \
119+ except files listed in --skip (default: optimizer.pt)." ,
120+ )
121+ ap .add_argument (
122+ "--slim-only" ,
123+ action = "store_true" ,
124+ help = "For directory inputs: DO NOT convert; just copy everything \
125+ except files in --skip." ,
126+ )
127+ ap .add_argument (
128+ "--skip" ,
129+ default = "optimizer.pt" ,
130+ help = "Comma-separated file names to skip during slimming \
131+ (applies to --slim or --slim-only). Default: optimizer.pt" ,
132+ )
133+
93134 args = ap .parse_args ()
94135
136+ if args .slim and args .slim_only :
137+ raise SystemExit ("Choose at most one of: --slim or --slim-only." )
138+
139+ skip_list = (
140+ [s .strip () for s in args .skip .split ("," )]
141+ if (args .slim or args .slim_only )
142+ else []
143+ )
144+
95145 p = args .input
96146 if p .is_file ():
97147 sfx = p .suffix .lower ()
@@ -111,10 +161,19 @@ def main():
111161 else :
112162 raise SystemExit (f"Unsupported file type: { p } " )
113163 elif p .is_dir ():
114- if any (x .suffix == ".safetensors" for x in p .iterdir ()):
115- convert_dir_of_safetensors (p , args .output )
116- else :
164+ if not any (x .suffix == ".safetensors" for x in p .iterdir ()):
117165 raise SystemExit ("Directory has no .safetensors files." )
166+ if args .slim_only :
167+ slim_copy_dir_skip_only (p , args .output , skip_list )
168+ else :
169+ convert_dir_of_safetensors (p , args .output )
170+ if args .slim :
171+ tmp = args .output .parent / (args .output .name + "_tmp_slim" )
172+ if tmp .exists ():
173+ shutil .rmtree (tmp )
174+ slim_copy_dir_skip_only (args .output , tmp , skip_list )
175+ shutil .rmtree (args .output )
176+ tmp .rename (args .output )
118177 else :
119178 raise SystemExit (f"Not found: { p } " )
120179
0 commit comments