11import asyncio
22import base64
33import os
4- from io import BytesIO
54import time
5+ from io import BytesIO
66from typing import List , Tuple
77
8+ import aiohttp
89import numpy as np
910import requests
1011from accelerate import Accelerator , DistributedType
1112from loguru import logger as eval_logger
1213from scipy .io import wavfile
13- import aiohttp
1414from tqdm import tqdm
1515from transformers import AutoProcessor
1616
2828class WhisperTT (lmms ):
2929 """
3030 Whisper Audio Model - HTTP API Client
31-
31+
3232 This implementation uses HTTP calls to the tt-media-server instead of
3333 direct ttnn/tt-metal execution, allowing evals to run outside docker.
3434 """
@@ -52,23 +52,23 @@ def __init__(
5252 # Log warning for unexpected kwargs but don't fail
5353 if kwargs :
5454 eval_logger .warning (f"Ignoring unexpected kwargs: { kwargs } " )
55-
55+
5656 # Get base URL from env var or argument
5757 self .base_url = base_url or os .getenv ("OPENAI_API_BASE" , "http://127.0.0.1:8000" )
5858 self .timeout = timeout
5959 self .max_retries = max_retries
6060 self .pretrained = pretrained
61-
61+
6262 # Get API key from environment
6363 self .api_key = os .getenv ("OPENAI_API_KEY" , "your-secret-key" )
64-
64+
6565 eval_logger .info (f"Initializing WhisperTT HTTP client with base_url: { self .base_url } " )
66-
66+
6767 # Setup processor for tokenization
6868 self .processor = AutoProcessor .from_pretrained (pretrained )
6969 self .processor .tokenizer .set_prefix_tokens (language = language , task = task )
7070 self ._tokenizer = self .processor .tokenizer
71-
71+
7272 # Setup accelerator for distributed evaluation
7373 accelerator = Accelerator ()
7474 if accelerator .num_processes > 1 :
@@ -79,7 +79,7 @@ def __init__(
7979 self ._device = device
8080 self ._rank = 0
8181 self ._world_size = 1
82-
82+
8383 self .batch_size_per_gpu = int (batch_size )
8484 self .use_cache = use_cache
8585
@@ -110,41 +110,41 @@ def world_size(self):
def encode_audio_to_base64_wav(self, audio_array: np.ndarray, sampling_rate: int) -> str:
    """
    Convert audio samples to a base64-encoded WAV file.

    Args:
        audio_array: Audio samples. A numpy array is expected, but any
            array-like (e.g. a plain list of floats) is accepted and
            converted.
        sampling_rate: Sampling rate of the audio in Hz, written into the
            WAV header.

    Returns:
        The WAV file contents as a base64 string.
    """
    # Force float32 so scipy writes a 32-bit float WAV; float64 input would
    # otherwise yield a 64-bit file, which the server rejects with
    # "Unsupported bit depth: 64". np.asarray (unlike .astype) also accepts
    # non-ndarray input and skips the copy when the data is already float32.
    samples = np.asarray(audio_array, dtype=np.float32)

    # Serialize the WAV container entirely in memory; no temp file needed.
    with BytesIO() as wav_buffer:
        wavfile.write(wav_buffer, sampling_rate, samples)
        wav_bytes = wav_buffer.getvalue()

    # Base64 output is pure ASCII, so decoding as UTF-8 cannot fail.
    return base64.b64encode(wav_bytes).decode("utf-8")
133133
134134 def transcribe_audio (self , audio_array : np .ndarray , sampling_rate : int ) -> str :
135135 """
136136 Transcribe audio using the tt-media-server HTTP API.
137-
137+
138138 Args:
139139 audio_array: Audio data as numpy array
140140 sampling_rate: Sampling rate of the audio
141-
141+
142142 Returns:
143143 Transcription text
144144 """
145145 # Encode audio to base64 WAV
146146 base64_audio = self .encode_audio_to_base64_wav (audio_array , sampling_rate )
147-
147+
148148 # Prepare request
149149 url = f"{ self .base_url } /audio/transcriptions"
150150 headers = {
@@ -153,58 +153,50 @@ def transcribe_audio(self, audio_array: np.ndarray, sampling_rate: int) -> str:
153153 }
154154 if self .api_key :
155155 headers ["Authorization" ] = f"Bearer { self .api_key } "
156-
157- payload = {
158- "file" : base64_audio ,
159- "stream" : False
160- }
161-
156+
157+ payload = {"file" : base64_audio , "stream" : False }
158+
162159 # Make request with retries
163160 for attempt in range (self .max_retries ):
164161 try :
165- response = requests .post (
166- url ,
167- json = payload ,
168- headers = headers ,
169- timeout = self .timeout
170- )
162+ response = requests .post (url , json = payload , headers = headers , timeout = self .timeout )
171163 response .raise_for_status ()
172-
164+
173165 # Parse response
174166 result = response .json ()
175-
167+
176168 # Extract transcription text from response
177169 # The response format should contain the transcription
178170 if isinstance (result , dict ):
179171 # Try common keys for transcription text
180- transcription = result .get (' text' ) or result .get (' transcription' ) or result .get (' result' )
172+ transcription = result .get (" text" ) or result .get (" transcription" ) or result .get (" result" )
181173 if transcription :
182174 return transcription
183175 # If no known key, return the entire dict as string
184176 eval_logger .warning (f"Unexpected response format: { result } " )
185177 return str (result )
186178 else :
187179 return str (result )
188-
180+
189181 except requests .exceptions .RequestException as e :
190182 if attempt < self .max_retries - 1 :
191183 eval_logger .warning (f"Request failed (attempt { attempt + 1 } /{ self .max_retries } ): { e } " )
192184 continue
193185 else :
194186 eval_logger .error (f"All retry attempts failed: { e } " )
195187 raise
196-
188+
197189 return ""
198190
199191 async def _generate_audio_transcription (self , session , audio_array : np .ndarray , sampling_rate : int , audio_index : int = None ) -> str :
200192 """
201193 Transcribe audio using the tt-media-server HTTP API.
202-
194+
203195 Args:
204196 audio_array: Audio data as numpy array
205197 sampling_rate: Sampling rate of the audio
206198 audio_index: Index of the audio for logging purposes
207-
199+
208200 Returns:
209201 Transcription text
210202 """
@@ -216,38 +208,27 @@ async def _generate_audio_transcription(self, session, audio_array: np.ndarray,
216208
217209 # Prepare request
218210 url = f"{ self .base_url } /audio/transcriptions"
219- headers = {
220- "accept" : "application/json" ,
221- "Content-Type" : "application/json"
222- }
211+ headers = {"accept" : "application/json" , "Content-Type" : "application/json" }
223212 if self .api_key :
224213 headers ["Authorization" ] = f"Bearer { self .api_key } "
225-
226- payload = {
227- "file" : base64_audio ,
228- "stream" : False
229- }
214+
215+ payload = {"file" : base64_audio , "stream" : False }
230216
231217 try :
232- async with session .post (
233- f"{ self .base_url } /audio/transcriptions" ,
234- json = payload ,
235- headers = headers ,
236- timeout = aiohttp .ClientTimeout (total = 15000 )
237- ) as response :
218+ async with session .post (f"{ self .base_url } /audio/transcriptions" , json = payload , headers = headers , timeout = aiohttp .ClientTimeout (total = 15000 )) as response :
238219 elapsed = time .time () - start_time
239220
240221 if response .status != 200 :
241222 eval_logger .info (f"❌ Audio transcription failed with status: { response .status } " )
242223 return ""
243224
244225 result = await response .json ()
245-
226+
246227 # Extract transcription text from response
247228 # The response format should contain the transcription
248229 if isinstance (result , dict ):
249230 # Try common keys for transcription text
250- transcription = result .get (' text' ) or result .get (' transcription' ) or result .get (' result' )
231+ transcription = result .get (" text" ) or result .get (" transcription" ) or result .get (" result" )
251232 eval_logger .info (f"Transcription result for audio { audio_index } : { transcription } " )
252233 if transcription :
253234 return transcription
@@ -282,18 +263,18 @@ def _collate(x):
282263 return - len (toks ), x [0 ]
283264
284265 pbar = tqdm (total = len (requests ), disable = (self .rank != 0 ), desc = "Model Responding" )
285-
266+
286267 # Group requests by their generation_kwargs
287268 re_ords = utils .Collator ([reg .args for reg in requests ], _collate , grouping = True )
288269 chunks = re_ords .get_batched (n = self .batch_size , batch_fn = None )
289-
270+
290271 # Collect all audios from all chunks first
291272 all_audios = []
292273 all_contexts = []
293274 all_gen_kwargs_list = []
294-
275+
295276 time_start = time .time ()
296-
277+
297278 for chunk in chunks :
298279 contexts , all_gen_kwargs , doc_to_visual , doc_id , task , split = zip (* chunk )
299280 task = task [0 ]
@@ -318,12 +299,12 @@ def _collate(x):
318299 sampling_rate = self .processor .feature_extractor .sampling_rate
319300 assert sampling_rate == SAMPLING_RATE , f"Expected sampling rate { SAMPLING_RATE } , but got { sampling_rate } "
320301 audios = [downsample_audio (audio ["array" ], audio ["sampling_rate" ], sampling_rate ) for audio in flattened_audios ]
321-
302+
322303 # Collect all data
323304 all_audios .extend (audios )
324305 all_contexts .extend (contexts )
325306 all_gen_kwargs_list .extend ([gen_kwargs ] * len (contexts ))
326-
307+
327308 time_end_prep = time .time ()
328309 eval_logger .info (f"Preparation time for { len (all_audios )} requests: { time_end_prep - time_start :.2f} s" )
329310
@@ -334,7 +315,7 @@ async def run_transcriptions():
334315 return await asyncio .gather (* tasks )
335316
336317 answers = asyncio .run (run_transcriptions ())
337-
318+
338319 time_end_process = time .time ()
339320
340321 eval_logger .info (f"Total time for { len (all_audios )} requests across all chunks { time_end_process - time_start :.2f} s" )
@@ -348,11 +329,11 @@ async def run_transcriptions():
348329 until = gen_kwargs ["until" ]
349330 if isinstance (until , str ):
350331 until = [until ]
351-
332+
352333 for term in until :
353334 if len (term ) > 0 :
354335 ans = ans .split (term )[0 ]
355-
336+
356337 processed_answers .append (ans )
357338
358339 for ans , context , gen_kwargs in zip (processed_answers , all_contexts , all_gen_kwargs_list ):
@@ -364,9 +345,9 @@ async def run_transcriptions():
364345 res = re_ords .get_original (res )
365346
366347 pbar .close ()
367-
348+
368349 time_end_process = time .time ()
369-
350+
370351 eval_logger .info (f"Total time for { len (all_audios )} requests across all chunks { time_end_process - time_start :.2f} s" )
371352
372353 return res
0 commit comments