77import mlx .nn as nn
88
99from ..base import STTOutput
10- from ..wav2vec .wav2vec import (
11- ModelConfig ,
12- Wav2Vec2Model ,
13- )
10+ from ..wav2vec .wav2vec import ModelConfig , Wav2Vec2Model
1411
1512
1613class Model (nn .Module ):
@@ -74,6 +71,7 @@ def generate(
7471
7572 if isinstance (audio , (str , Path )):
7673 from mlx_audio .stt .utils import load_audio
74+
7775 audio = load_audio (str (audio ), sr = self .sample_rate , dtype = dtype )
7876 elif not isinstance (audio , mx .array ):
7977 audio = mx .array (audio )
@@ -84,7 +82,9 @@ def generate(
8482 if audio .dtype != dtype :
8583 audio = audio .astype (dtype )
8684
87- audio = (audio - mx .mean (audio , axis = - 1 , keepdims = True )) / (mx .std (audio , axis = - 1 , keepdims = True ) + 1e-7 )
85+ audio = (audio - mx .mean (audio , axis = - 1 , keepdims = True )) / (
86+ mx .std (audio , axis = - 1 , keepdims = True ) + 1e-7
87+ )
8888
8989 logits = self (audio )
9090 mx .eval (logits )
@@ -147,13 +147,16 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
147147 with open (vocab_path ) as f :
148148 vocab = json .load (f )
149149 if isinstance (next (iter (vocab .values ())), dict ):
150- lang_vocab = vocab .get ("eng" , vocab .get ("en" , next (iter (vocab .values ()))))
150+ lang_vocab = vocab .get (
151+ "eng" , vocab .get ("en" , next (iter (vocab .values ())))
152+ )
151153 model ._vocab = {v : k for k , v in lang_vocab .items ()}
152154 else :
153155 model ._vocab = {v : k for k , v in vocab .items ()}
154156
155157 try :
156158 from transformers import AutoProcessor
159+
157160 model ._processor = AutoProcessor .from_pretrained (str (model_path ))
158161 except Exception :
159162 model ._processor = None
0 commit comments