Skip to content

Commit e52c128

Browse files
Use model.safetensors with Whisper (#1399)
1 parent 7ddca42 commit e52c128

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

whisper/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def quantize(weights, config, args):
382382

383383
# Save weights
384384
print("[INFO] Saving")
385-
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
385+
mx.save_safetensors(str(mlx_path / "model.safetensors"), weights)
386386

387387
# Save config.json with model_type
388388
with open(str(mlx_path / "config.json"), "w") as f:

whisper/mlx_whisper/load_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ def load_model(
2626

2727
model_args = whisper.ModelDimensions(**config)
2828

29-
wf = model_path / "weights.safetensors"
29+
# Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
30+
wf = model_path / "model.safetensors"
31+
if not wf.exists():
32+
wf = model_path / "weights.safetensors"
3033
if not wf.exists():
3134
wf = model_path / "weights.npz"
3235
weights = mx.load(str(wf))

0 commit comments

Comments
 (0)