Skip to content

Commit 42a0f41

Browse files
committed
Pre-commit fixes
1 parent 36d3c09 commit 42a0f41

File tree

3 files changed

+164
-140
lines changed

3 files changed

+164
-140
lines changed

src/hf_mem/cli.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
import httpx
1414

15+
from hf_mem.gguf import GGUFDtype, GGUFMetadata, fetch_gguf_with_semaphore, gguf_metadata_to_json, merge_shards
1516
from hf_mem.metadata import parse_safetensors_metadata
1617
from hf_mem.print import print_report, print_report_for_gguf
1718
from hf_mem.types import TorchDtypes, get_safetensors_dtype_bytes, torch_dtype_to_safetensors_dtype
18-
from hf_mem.gguf import fetch_gguf_with_semaphore, gguf_metadata_to_json, merge_shards, GGUFMetadata, GGUFDtype
1919

2020
# NOTE: Defines the bytes that will be fetched per safetensors file, but the metadata
2121
# can indeed be larger than that
@@ -125,39 +125,48 @@ async def run(
125125
url = f"https://huggingface.co/api/models/{model_id}/tree/{revision}?recursive=true"
126126
files = await get_json_file(client=client, url=url, headers=headers)
127127
file_paths = [f["path"] for f in files if f.get("path") and f.get("type") == "file"]
128-
129128

130129
# NOTE: GGUF support only applies if:
131130
# 1. The `--gguf-file` flag is set.
132131
# 2. No Safetensors files are found and at least one gguf file is found
133132
gguf_paths = [f for f in file_paths if str(f).endswith(".gguf")]
134-
has_safetensors = any(f in ["model.safetensors", "model.safetensors.index.json", "model_index.json"] for f in file_paths)
133+
has_safetensors = any(
134+
f in ["model.safetensors", "model.safetensors.index.json", "model_index.json"] for f in file_paths
135+
)
135136
gguf = gguf_file is not None or (gguf_paths and not has_safetensors)
136-
137+
137138
if not gguf and (has_safetensors and gguf_paths):
138139
warnings.warn(
139140
f"Both Safetensors and GGUF files have been found for {model_id} @ {revision}, if you want to estimate any of the GGUF file sizes, please use the `--gguf-file` flag with the path to the specific GGUF file. GGUF files found: {gguf_paths}."
140141
)
141142

142143
if gguf:
143144
if kv_cache_dtype not in GGUFDtype.__members__ and kv_cache_dtype != "auto":
144-
raise RuntimeError(f"--kv-cache-dtype={kv_cache_dtype} not recognized for GGUF files. Valid options: {list(GGUFDtype.__members__.keys())} or `auto`.")
145-
145+
raise RuntimeError(
146+
f"--kv-cache-dtype={kv_cache_dtype} not recognized for GGUF files. Valid options: {list(GGUFDtype.__members__.keys())} or `auto`."
147+
)
148+
146149
if not gguf_paths:
147150
raise RuntimeError(f"No GGUF files found for {model_id} @ {revision}.")
148-
151+
149152
if gguf_file:
150153
# Check if it's a sharded file (model-00001-of-00046.gguf)
151-
if prefix_match := re.match(r'(.+)-\d+-of-\d+\.gguf$', gguf_file):
154+
if prefix_match := re.match(r"(.+)-\d+-of-\d+\.gguf$", gguf_file):
152155
# Keep all shards with the same prefix
153156
prefix = prefix_match.group(1)
154-
gguf_paths = [path for path in gguf_paths if re.match(rf'{re.escape(prefix)}-\d+-of-\d+\.gguf$', str(path))]
157+
gguf_paths = [
158+
path
159+
for path in gguf_paths
160+
if re.match(rf"{re.escape(prefix)}-\d+-of-\d+\.gguf$", str(path))
161+
]
155162
else:
156163
# Not sharded
157164
gguf_paths = [path for path in gguf_paths if str(path).endswith(gguf_file)]
158165
if len(gguf_paths) > 1:
159-
raise RuntimeError(f"Multiple GGUF files named `{gguf_file}` found for {model_id} @ {revision}.")
160-
166+
raise RuntimeError(
167+
f"Multiple GGUF files named `{gguf_file}` found for {model_id} @ {revision}."
168+
)
169+
161170
if not gguf_paths:
162171
raise RuntimeError(f"No GGUF file matching `{gguf_file}` found for {model_id} @ {revision}.")
163172

@@ -166,13 +175,15 @@ async def run(
166175
tasks = []
167176
for path in gguf_paths:
168177
# In sharded GGUF files tensor metadata also gets sharded, so we need to merge them all
169-
shard_pattern = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', str(path)) # Ex: Kimi-K2.5-BF16-00001-of-00046.gguf
178+
shard_pattern = re.match(
179+
r"(.+)-(\d+)-of-(\d+)\.gguf$", str(path)
180+
) # Ex: Kimi-K2.5-BF16-00001-of-00046.gguf
170181
parse_kv_cache = experimental
171182
# For sharded files, parsing kv_cache data might result in runtime errors (missing fields)
172183
if experimental and shard_pattern:
173-
shard_num = int(shard_pattern.group(2)) # Get first number
174-
parse_kv_cache = (shard_num == 1)
175-
184+
shard_num = int(shard_pattern.group(2)) # Get first number
185+
parse_kv_cache = shard_num == 1
186+
176187
task = asyncio.create_task(
177188
fetch_gguf_with_semaphore(
178189
semaphore=semaphore,
@@ -203,19 +214,16 @@ async def run(
203214
else:
204215
gguf_files[base_name] = metadata
205216
else:
206-
gguf_files[path] = metadata
207-
208-
217+
gguf_files[path] = metadata
218+
209219
if json_output:
210220
print(
211-
json.dumps([
212-
gguf_metadata_to_json(
213-
model_id=filename,
214-
revision=revision,
215-
metadata=gguf_metadata
216-
)
217-
for filename, gguf_metadata in gguf_files.items()
218-
])
221+
json.dumps(
222+
[
223+
gguf_metadata_to_json(model_id=filename, revision=revision, metadata=gguf_metadata)
224+
for filename, gguf_metadata in gguf_files.items()
225+
]
226+
)
219227
)
220228
else:
221229
if gguf_file:
@@ -246,11 +254,11 @@ async def run(
246254
else:
247255
# For multiple files, we use the new one
248256
print_report_for_gguf(
249-
model_id=model_id,
250-
revision=revision,
251-
gguf_files=gguf_files,
252-
ignore_table_width=ignore_table_width
253-
)
257+
model_id=model_id,
258+
revision=revision,
259+
gguf_files=gguf_files,
260+
ignore_table_width=ignore_table_width,
261+
)
254262
return
255263
elif "model.safetensors" in file_paths:
256264
url = f"https://huggingface.co/{model_id}/resolve/{revision}/model.safetensors"
@@ -581,9 +589,11 @@ def main() -> None:
581589
warnings.warn(
582590
"`--experimental` is set, which means that models with an architecture as `...ForCausalLM` and `...ForConditionalGeneration` will include estimations for the KV Cache as well. You can also provide the args `--max-model-len` and `--batch-size` as part of the estimation. Note that enabling `--experimental` means that the output will be different both when displayed and when dumped as JSON with `--json-output`, so bear that in mind."
583591
)
584-
592+
585593
if args.kv_cache_dtype not in KV_CACHE_DTYPE_CHOICES:
586-
raise RuntimeError(f"--kv-cache-dtype={args.kv_cache_dtype} not recognized. Valid options: {KV_CACHE_DTYPE_CHOICES}.")
594+
raise RuntimeError(
595+
f"--kv-cache-dtype={args.kv_cache_dtype} not recognized. Valid options: {KV_CACHE_DTYPE_CHOICES}."
596+
)
587597

588598
asyncio.run(
589599
run(
@@ -598,6 +608,6 @@ def main() -> None:
598608
json_output=args.json_output,
599609
ignore_table_width=args.ignore_table_width,
600610
# NOTE: GGUF flags
601-
gguf_file=args.gguf_file
611+
gguf_file=args.gguf_file,
602612
)
603613
)

0 commit comments

Comments (0)