1212
1313import httpx
1414
15+ from hf_mem .gguf import GGUFDtype , GGUFMetadata , fetch_gguf_with_semaphore , gguf_metadata_to_json , merge_shards
1516from hf_mem .metadata import parse_safetensors_metadata
1617from hf_mem .print import print_report , print_report_for_gguf
1718from hf_mem .types import TorchDtypes , get_safetensors_dtype_bytes , torch_dtype_to_safetensors_dtype
18- from hf_mem .gguf import fetch_gguf_with_semaphore , gguf_metadata_to_json , merge_shards , GGUFMetadata , GGUFDtype
1919
2020# NOTE: Defines the bytes that will be fetched per safetensors file, but the metadata
2121# can indeed be larger than that
@@ -125,39 +125,48 @@ async def run(
125125 url = f"https://huggingface.co/api/models/{ model_id } /tree/{ revision } ?recursive=true"
126126 files = await get_json_file (client = client , url = url , headers = headers )
127127 file_paths = [f ["path" ] for f in files if f .get ("path" ) and f .get ("type" ) == "file" ]
128-
129128
130129 # NOTE: GGUF support only applies if:
131130 # 1. The `--gguf-file` flag is set.
132131 # 2. No Safetensors files are found and at least one gguf file is found
133132 gguf_paths = [f for f in file_paths if str (f ).endswith (".gguf" )]
134- has_safetensors = any (f in ["model.safetensors" , "model.safetensors.index.json" , "model_index.json" ] for f in file_paths )
133+ has_safetensors = any (
134+ f in ["model.safetensors" , "model.safetensors.index.json" , "model_index.json" ] for f in file_paths
135+ )
135136 gguf = gguf_file is not None or (gguf_paths and not has_safetensors )
136-
137+
137138 if not gguf and (has_safetensors and gguf_paths ):
138139 warnings .warn (
139140 f"Both Safetensors and GGUF files have been found for { model_id } @ { revision } , if you want to estimate any of the GGUF file sizes, please use the `--gguf-file` flag with the path to the specific GGUF file. GGUF files found: { gguf_paths } ."
140141 )
141142
142143 if gguf :
143144 if kv_cache_dtype not in GGUFDtype .__members__ and kv_cache_dtype != "auto" :
144- raise RuntimeError (f"--kv-cache-dtype={ kv_cache_dtype } not recognized for GGUF files. Valid options: { list (GGUFDtype .__members__ .keys ())} or `auto`." )
145-
145+ raise RuntimeError (
146+ f"--kv-cache-dtype={ kv_cache_dtype } not recognized for GGUF files. Valid options: { list (GGUFDtype .__members__ .keys ())} or `auto`."
147+ )
148+
146149 if not gguf_paths :
147150 raise RuntimeError (f"No GGUF files found for { model_id } @ { revision } ." )
148-
151+
149152 if gguf_file :
150153 # Check if it's a sharded file (model-00001-of-00046.gguf)
151- if prefix_match := re .match (r' (.+)-\d+-of-\d+\.gguf$' , gguf_file ):
154+ if prefix_match := re .match (r" (.+)-\d+-of-\d+\.gguf$" , gguf_file ):
152155 # Keep all shards with the same prefix
153156 prefix = prefix_match .group (1 )
154- gguf_paths = [path for path in gguf_paths if re .match (rf'{ re .escape (prefix )} -\d+-of-\d+\.gguf$' , str (path ))]
157+ gguf_paths = [
158+ path
159+ for path in gguf_paths
160+ if re .match (rf"{ re .escape (prefix )} -\d+-of-\d+\.gguf$" , str (path ))
161+ ]
155162 else :
156163 # Not sharded
157164 gguf_paths = [path for path in gguf_paths if str (path ).endswith (gguf_file )]
158165 if len (gguf_paths ) > 1 :
159- raise RuntimeError (f"Multiple GGUF files named `{ gguf_file } ` found for { model_id } @ { revision } ." )
160-
166+ raise RuntimeError (
167+ f"Multiple GGUF files named `{ gguf_file } ` found for { model_id } @ { revision } ."
168+ )
169+
161170 if not gguf_paths :
162171 raise RuntimeError (f"No GGUF file matching `{ gguf_file } ` found for { model_id } @ { revision } ." )
163172
@@ -166,13 +175,15 @@ async def run(
166175 tasks = []
167176 for path in gguf_paths :
168177 # In sharded GGUF files tensor metadata also gets sharded, so we need to merge them all
169- shard_pattern = re .match (r'(.+)-(\d+)-of-(\d+)\.gguf$' , str (path )) # Ex: Kimi-K2.5-BF16-00001-of-00046.gguf
178+ shard_pattern = re .match (
179+ r"(.+)-(\d+)-of-(\d+)\.gguf$" , str (path )
180+ ) # Ex: Kimi-K2.5-BF16-00001-of-00046.gguf
170181 parse_kv_cache = experimental
171182 # For sharded files, parsing kv_cache data might result in runtime errors (missing fields)
172183 if experimental and shard_pattern :
173- shard_num = int (shard_pattern .group (2 )) # Get first number
174- parse_kv_cache = ( shard_num == 1 )
175-
184+ shard_num = int (shard_pattern .group (2 )) # Get first number
185+ parse_kv_cache = shard_num == 1
186+
176187 task = asyncio .create_task (
177188 fetch_gguf_with_semaphore (
178189 semaphore = semaphore ,
@@ -203,19 +214,16 @@ async def run(
203214 else :
204215 gguf_files [base_name ] = metadata
205216 else :
206- gguf_files [path ] = metadata
207-
208-
217+ gguf_files [path ] = metadata
218+
209219 if json_output :
210220 print (
211- json .dumps ([
212- gguf_metadata_to_json (
213- model_id = filename ,
214- revision = revision ,
215- metadata = gguf_metadata
216- )
217- for filename , gguf_metadata in gguf_files .items ()
218- ])
221+ json .dumps (
222+ [
223+ gguf_metadata_to_json (model_id = filename , revision = revision , metadata = gguf_metadata )
224+ for filename , gguf_metadata in gguf_files .items ()
225+ ]
226+ )
219227 )
220228 else :
221229 if gguf_file :
@@ -246,11 +254,11 @@ async def run(
246254 else :
247255 # For multiple files, we use the new one
248256 print_report_for_gguf (
249- model_id = model_id ,
250- revision = revision ,
251- gguf_files = gguf_files ,
252- ignore_table_width = ignore_table_width
253- )
257+ model_id = model_id ,
258+ revision = revision ,
259+ gguf_files = gguf_files ,
260+ ignore_table_width = ignore_table_width ,
261+ )
254262 return
255263 elif "model.safetensors" in file_paths :
256264 url = f"https://huggingface.co/{ model_id } /resolve/{ revision } /model.safetensors"
@@ -581,9 +589,11 @@ def main() -> None:
581589 warnings .warn (
582590 "`--experimental` is set, which means that models with an architecture as `...ForCausalLM` and `...ForConditionalGeneration` will include estimations for the KV Cache as well. You can also provide the args `--max-model-len` and `--batch-size` as part of the estimation. Note that enabling `--experimental` means that the output will be different both when displayed and when dumped as JSON with `--json-output`, so bear that in mind."
583591 )
584-
592+
585593 if args .kv_cache_dtype not in KV_CACHE_DTYPE_CHOICES :
586- raise RuntimeError (f"--kv-cache-dtype={ args .kv_cache_dtype } not recognized. Valid options: { KV_CACHE_DTYPE_CHOICES } ." )
594+ raise RuntimeError (
595+ f"--kv-cache-dtype={ args .kv_cache_dtype } not recognized. Valid options: { KV_CACHE_DTYPE_CHOICES } ."
596+ )
587597
588598 asyncio .run (
589599 run (
@@ -598,6 +608,6 @@ def main() -> None:
598608 json_output = args .json_output ,
599609 ignore_table_width = args .ignore_table_width ,
600610 # NOTE: GGUF flags
601- gguf_file = args .gguf_file
611+ gguf_file = args .gguf_file ,
602612 )
603613 )
0 commit comments