44
55# Standard
66from pathlib import Path
7- from typing import Any
7+ from typing import Any , Iterable , Set
88import argparse
99import shutil
1010
@@ -29,7 +29,9 @@ def cast_fp32_to_bf16(x: Any, *, in_optim: bool = False) -> Any:
2929 out = {}
3030 for k , v in x .items ():
3131 k_lower = k .lower () if isinstance (k , str ) else ""
32- child_in_optim = in_optim or (k_lower in OPTIM_ROOT_KEYS )
32+ child_in_optim = in_optim or any (
33+ k_lower .startswith (root ) for root in OPTIM_ROOT_KEYS
34+ )
3335 out [k ] = cast_fp32_to_bf16 (v , in_optim = child_in_optim )
3436 return out
3537 if isinstance (x , (list , tuple )):
@@ -46,8 +48,8 @@ def convert_pt_pth(inp: Path, out: Path) -> None:
4648
4749
def is_optim_tensor_name(name: str) -> bool:
    """Return True when a tensor name's leading component starts with an optimizer root key.

    The name is lower-cased and "/" is treated like "." before the first
    dot-separated component is taken. A falsy name (None or "") is handled
    as the empty string.
    """
    normalized = (name or "").lower().replace("/", ".")
    # Only the component before the first "." takes part in the prefix check.
    head, _, _ = normalized.partition(".")
    return head.startswith(tuple(OPTIM_ROOT_KEYS))
5153
5254
5355def convert_safetensors_file (inp : Path , out : Path ) -> None :
@@ -65,6 +67,27 @@ def convert_safetensors_file(inp: Path, out: Path) -> None:
6567 print (f"[safetensors] wrote: { out } " )
6668
6769
def slim_copy_dir_skip_only(src: Path, dst: Path, skip_names: Iterable[str]) -> None:
    """
    Copy the top-level entries of src into dst, leaving out skipped names.

    Only the names of direct children of src are compared against skip_names.
    A directory that is not skipped is copied with its whole subtree, so
    skip_names is not applied to entries nested inside it.

    Args:
        src: Directory whose entries are copied.
        dst: Destination directory; created (with parents) if missing.
        skip_names: Entry names to leave out. A single string is treated as
            one name rather than as an iterable of characters.
    """
    # A bare str is itself an Iterable[str]: set("optimizer.pt") would produce
    # single characters and the intended file would silently not be skipped.
    if isinstance(skip_names, str):
        skip_names = [skip_names]
    skip_set: Set[str] = set(skip_names)

    dst.mkdir(parents=True, exist_ok=True)
    dst_resolved = dst.resolve()

    for item in src.iterdir():
        if item.name in skip_set:
            continue
        # If dst is a direct child of src it shows up in this listing;
        # copying it would nest the destination inside itself.
        if item.resolve() == dst_resolved:
            continue
        target = dst / item.name
        if item.is_file():
            shutil.copy2(item, target)
        elif item.is_dir():
            shutil.copytree(item, target, dirs_exist_ok=True)

    # Sorted so the report is stable across runs (set iteration order is not).
    skipped = ", ".join(sorted(skip_set)) if skip_set else "none"
    print(f"[slim] wrote: {dst} (skipped: {skipped})")
89+
90+
6891def convert_dir_of_safetensors (src : Path , dst : Path ) -> None :
6992 """Convert all .safetensors in a directory; copy other files as-is."""
7093 dst .mkdir (parents = True , exist_ok = True )
@@ -82,16 +105,45 @@ def convert_dir_of_safetensors(src: Path, dst: Path) -> None:
82105
83106def main ():
84107 ap = argparse .ArgumentParser (
85- description = "Convert FP32 tensors to BF16 (skip optimizer states)."
108+ description = "Convert FP32 tensors to BF16 (skips optimizer states)."
86109 )
87110 ap .add_argument (
88111 "input" ,
89112 type = Path ,
90113 help = "Input: .pt/.pth, .safetensors, or HF directory with .safetensors" ,
91114 )
92115 ap .add_argument ("output" , type = Path , help = "Output file or directory" )
116+
117+ ap .add_argument (
118+ "--slim" ,
119+ action = "store_true" ,
120+ help = "For directory inputs: after conversion, copy everything \
121+ except files listed in --skip (default: optimizer.pt)." ,
122+ )
123+ ap .add_argument (
124+ "--slim-only" ,
125+ action = "store_true" ,
126+ help = "For directory inputs: DO NOT convert; just copy everything \
127+ except files in --skip." ,
128+ )
129+ ap .add_argument (
130+ "--skip" ,
131+ default = "optimizer.pt" ,
132+ help = "Comma-separated file names to skip during slimming \
133+ (applies to --slim or --slim-only). Default: optimizer.pt" ,
134+ )
135+
93136 args = ap .parse_args ()
94137
138+ if args .slim and args .slim_only :
139+ raise SystemExit ("Choose at most one of: --slim or --slim-only." )
140+
141+ skip_list = (
142+ [s .strip () for s in args .skip .split ("," )]
143+ if (args .slim or args .slim_only )
144+ else []
145+ )
146+
95147 p = args .input
96148 if p .is_file ():
97149 sfx = p .suffix .lower ()
@@ -111,10 +163,19 @@ def main():
111163 else :
112164 raise SystemExit (f"Unsupported file type: { p } " )
113165 elif p .is_dir ():
114- if any (x .suffix == ".safetensors" for x in p .iterdir ()):
115- convert_dir_of_safetensors (p , args .output )
116- else :
166+ if not any (x .suffix == ".safetensors" for x in p .iterdir ()):
117167 raise SystemExit ("Directory has no .safetensors files." )
168+ if args .slim_only :
169+ slim_copy_dir_skip_only (p , args .output , skip_list )
170+ else :
171+ convert_dir_of_safetensors (p , args .output )
172+ if args .slim :
173+ tmp = args .output .parent / (args .output .name + "_tmp_slim" )
174+ if tmp .exists ():
175+ shutil .rmtree (tmp )
176+ slim_copy_dir_skip_only (args .output , tmp , skip_list )
177+ shutil .rmtree (args .output )
178+ tmp .rename (args .output )
118179 else :
119180 raise SystemExit (f"Not found: { p } " )
120181
0 commit comments