Skip to content

Commit 0764a0e

Browse files
committed
Use optimal alignment heads for Whisper
1 parent e52c128 commit 0764a0e

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

whisper/convert.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@
5959
}
6060

6161

62+
def _get_model_variant(name_or_path: str) -> str | None:
63+
"""Extract model variant for alignment heads lookup."""
64+
if name_or_path in _ALIGNMENT_HEADS:
65+
return name_or_path
66+
67+
# Extract from repo name like "openai/whisper-large-v3"
68+
name = name_or_path.split("/")[-1]
69+
if name.startswith("whisper-"):
70+
return name[8:] # Remove "whisper-" prefix
71+
72+
return None
73+
74+
6275
def _download(url: str, root: str) -> str:
6376
os.makedirs(root, exist_ok=True)
6477

@@ -156,10 +169,11 @@ def load_torch_weights_and_config(
156169
if download_root is None:
157170
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
158171

159-
# todo: accept alignment_heads of local Pytorch checkpoint
160-
alignment_heads = None
172+
# Look up alignment heads using normalized variant name
173+
variant = _get_model_variant(name_or_path)
174+
alignment_heads = _ALIGNMENT_HEADS.get(variant) if variant else None
175+
161176
if name_or_path in _MODELS:
162-
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
163177
name_or_path = _download(_MODELS[name_or_path], download_root)
164178
elif not Path(name_or_path).exists():
165179
# Try downloading from HF

0 commit comments

Comments
 (0)