File tree Expand file tree Collapse file tree 1 file changed +17
-3
lines changed
Expand file tree Collapse file tree 1 file changed +17
-3
lines changed Original file line number Diff line number Diff line change 5959}
6060
6161
62+ def _get_model_variant (name_or_path : str ) -> str | None :
63+ """Extract model variant for alignment heads lookup."""
64+ if name_or_path in _ALIGNMENT_HEADS :
65+ return name_or_path
66+
67+ # Extract from repo name like "openai/whisper-large-v3"
68+ name = name_or_path .split ("/" )[- 1 ]
69+ if name .startswith ("whisper-" ):
70+ return name [8 :] # Remove "whisper-" prefix
71+
72+ return None
73+
74+
6275def _download (url : str , root : str ) -> str :
6376 os .makedirs (root , exist_ok = True )
6477
@@ -156,10 +169,11 @@ def load_torch_weights_and_config(
156169 if download_root is None :
157170 download_root = os .path .join (os .path .expanduser ("~" ), ".cache/whisper" )
158171
159- # todo: accept alignment_heads of local Pytorch checkpoint
160- alignment_heads = None
172+ # Look up alignment heads using normalized variant name
173+ variant = _get_model_variant (name_or_path )
174+ alignment_heads = _ALIGNMENT_HEADS .get (variant ) if variant else None
175+
161176 if name_or_path in _MODELS :
162- alignment_heads = _ALIGNMENT_HEADS [name_or_path ]
163177 name_or_path = _download (_MODELS [name_or_path ], download_root )
164178 elif not Path (name_or_path ).exists ():
165179 # Try downloading from HF
You can’t perform that action at this time.
0 commit comments