Commit 8daf445
fix: improve spatialviz utils quality (#961)
* fix: improve spatialviz utils quality
  - Fix FileExistsError -> FileNotFoundError (correct exception type)
  - Replace print() with eval_logger for consistent logging
  - Add type hints to all functions
  - Fix missing comma bug in final_answer_patterns list
  - Remove redundant image_path = image_path assignment
  - Initialize op variable to prevent potential UnboundLocalError
  - Break long prompt string for readability (88 char line limit)
* style: apply black formatting
* style: apply black and isort formatting to all files
1 parent c79490b commit 8daf445
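Note: the spatialviz utils file itself does not appear in the diff view below, but two of the fixes named in the commit message are classic Python pitfalls worth spelling out. A hypothetical sketch (only `final_answer_patterns` comes from the commit message; the surrounding names are illustrative, not from the repo):

```python
import os

# Missing-comma bug: adjacent string literals are implicitly concatenated,
# so one pattern silently vanishes from the list.
buggy_patterns = [
    "final answer is"  # BUG: no comma -> "final answer isthe answer is"
    "the answer is",
]
final_answer_patterns = [
    "final answer is",  # fixed: two separate patterns
    "the answer is",
]

# Correct exception type: a *missing* file is FileNotFoundError;
# FileExistsError signals a file that unexpectedly already exists.
def load_image_bytes(image_path: str) -> bytes:  # hypothetical helper
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    with open(image_path, "rb") as f:
        return f.read()
```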

File tree

14 files changed: +1300 −372 lines

lmms_eval/api/metrics.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -536,32 +536,33 @@ def bootstrap_chair_metric(metric_fn, xs, iters):
     print(f"bootstrapping for stddev: {metric_fn.__name__}")
     res = []
     from tqdm import tqdm
-
+
     for _ in tqdm(range(iters), desc="Bootstrap"):
         bootstrap_sample = random.choices(xs, k=len(xs))
         metric_value = metric_fn(bootstrap_sample)
         res.append(metric_value)
-
+
     return sample_stddev(res)
 
+
 def stderr_for_metric(metric, bootstrap_iters: int):
     if bootstrap_iters <= 0:
         # return no function (don't compute stderr) if bootstrap iters = 0
         return None
     # for coco_cap_chair
-    from lmms_eval.tasks.coco_cap_chair.utils import (
-        coco_cap_chair_aggregate_results_chair_i,
-        coco_cap_chair_aggregate_results_chair_s,
-        coco_cap_chair_aggregate_results_recall,
-    )
     # for amber_g
     from lmms_eval.tasks.amber_g.utils import (
         amber_g_aggregate_chair,
+        amber_g_aggregate_cog,
         amber_g_aggregate_cover,
         amber_g_aggregate_hal,
-        amber_g_aggregate_cog,
     )
-
+    from lmms_eval.tasks.coco_cap_chair.utils import (
+        coco_cap_chair_aggregate_results_chair_i,
+        coco_cap_chair_aggregate_results_chair_s,
+        coco_cap_chair_aggregate_results_recall,
+    )
+
     bootstrappable = [
         median,
         matthews_corrcoef,
@@ -582,10 +583,10 @@ def stderr_for_metric(metric, bootstrap_iters: int):
     if metric in bootstrappable:
         return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
 
-    if hasattr(metric, '__name__'):
-        if 'coco_cap_chair' in metric.__name__:
+    if hasattr(metric, "__name__"):
+        if "coco_cap_chair" in metric.__name__:
             return lambda x: bootstrap_chair_metric(metric, x, iters=bootstrap_iters)
-        if 'amber_g' in metric.__name__ or 'amber_' in metric.__name__:
+        if "amber_g" in metric.__name__ or "amber_" in metric.__name__:
             return lambda x: bootstrap_chair_metric(metric, x, iters=bootstrap_iters)
 
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
```
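For context on what `bootstrap_chair_metric` computes: it resamples the per-example results with replacement and reports the spread of the re-aggregated metric. A minimal standalone sketch of that idea (the helper names and toy data are illustrative, not the repo's implementation):

```python
import random

def sample_stddev_sketch(values: list) -> float:
    # Sample standard deviation of the bootstrap distribution.
    mean = sum(values) / len(values)
    return (sum((v - mean) ** 2 for v in values) / (len(values) - 1)) ** 0.5

def bootstrap_metric_stderr(metric_fn, xs: list, iters: int = 1000) -> float:
    # Resample xs with replacement, re-aggregate, and measure the spread.
    res = []
    for _ in range(iters):
        bootstrap_sample = random.choices(xs, k=len(xs))
        res.append(metric_fn(bootstrap_sample))
    return sample_stddev_sketch(res)

# Example: stderr of the mean over toy per-example scores
scores = [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
print(bootstrap_metric_stderr(lambda s: sum(s) / len(s), scores))
```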

lmms_eval/models/chat/openai_compatible.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -78,7 +78,7 @@ def generate_until(self, requests) -> List[str]:
         if "o1" in self.model_version or "o3" in self.model_version or "o4" in self.model_version or "gpt-5" in self.model_version:
             del payload["temperature"]
             payload.pop("max_tokens")
-            #payload["reasoning_effort"] = "medium"
+            # payload["reasoning_effort"] = "medium"
             payload["response_format"] = {"type": "text"}
             payload["max_completion_tokens"] = 5000
 
```
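The one-line change above is just black fixing comment spacing, but the surrounding logic is worth a note: OpenAI reasoning models reject `temperature` and `max_tokens`, so the payload is rewritten before the request. A hedged sketch of that pattern (the helper name is mine; the keys and the 5000-token cap come from the diff):

```python
def adjust_payload_for_reasoning_models(payload: dict, model_version: str) -> dict:
    # Reasoning models (o1/o3/o4/gpt-5) reject sampling controls, so drop
    # temperature and max_tokens and cap output via max_completion_tokens.
    if any(tag in model_version for tag in ("o1", "o3", "o4", "gpt-5")):
        payload.pop("temperature", None)
        payload.pop("max_tokens", None)
        # payload["reasoning_effort"] = "medium"  # left commented, as in the diff
        payload["response_format"] = {"type": "text"}
        payload["max_completion_tokens"] = 5000
    return payload

print(adjust_payload_for_reasoning_models(
    {"model": "o3-mini", "temperature": 0.2, "max_tokens": 256}, "o3-mini"
))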

lmms_eval/models/whisper_tt.py

Lines changed: 45 additions & 64 deletions
```diff
@@ -1,16 +1,16 @@
 import asyncio
 import base64
 import os
-from io import BytesIO
 import time
+from io import BytesIO
 from typing import List, Tuple
 
+import aiohttp
 import numpy as np
 import requests
 from accelerate import Accelerator, DistributedType
 from loguru import logger as eval_logger
 from scipy.io import wavfile
-import aiohttp
 from tqdm import tqdm
 from transformers import AutoProcessor
 
```
```diff
@@ -28,7 +28,7 @@
 class WhisperTT(lmms):
     """
     Whisper Audio Model - HTTP API Client
-
+
     This implementation uses HTTP calls to the tt-media-server instead of
     direct ttnn/tt-metal execution, allowing evals to run outside docker.
     """
```
```diff
@@ -52,23 +52,23 @@ def __init__(
         # Log warning for unexpected kwargs but don't fail
         if kwargs:
             eval_logger.warning(f"Ignoring unexpected kwargs: {kwargs}")
-
+
         # Get base URL from env var or argument
         self.base_url = base_url or os.getenv("OPENAI_API_BASE", "http://127.0.0.1:8000")
         self.timeout = timeout
         self.max_retries = max_retries
         self.pretrained = pretrained
-
+
         # Get API key from environment
         self.api_key = os.getenv("OPENAI_API_KEY", "your-secret-key")
-
+
         eval_logger.info(f"Initializing WhisperTT HTTP client with base_url: {self.base_url}")
-
+
         # Setup processor for tokenization
         self.processor = AutoProcessor.from_pretrained(pretrained)
         self.processor.tokenizer.set_prefix_tokens(language=language, task=task)
         self._tokenizer = self.processor.tokenizer
-
+
         # Setup accelerator for distributed evaluation
         accelerator = Accelerator()
         if accelerator.num_processes > 1:
@@ -79,7 +79,7 @@ def __init__(
             self._device = device
             self._rank = 0
             self._world_size = 1
-
+
         self.batch_size_per_gpu = int(batch_size)
         self.use_cache = use_cache
 
```
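The configuration precedence in `__init__` is easy to miss in the whitespace noise: explicit argument first, then the `OPENAI_API_BASE` environment variable, then a localhost default. As a one-liner sketch (helper name is mine):

```python
import os
from typing import Optional

def resolve_base_url(base_url: Optional[str] = None) -> str:
    # Argument > OPENAI_API_BASE env var > localhost default, as in __init__.
    return base_url or os.getenv("OPENAI_API_BASE", "http://127.0.0.1:8000")
```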
```diff
@@ -110,41 +110,41 @@ def world_size(self):
     def encode_audio_to_base64_wav(self, audio_array: np.ndarray, sampling_rate: int) -> str:
         """
         Convert audio numpy array to base64-encoded WAV format.
-
+
         Args:
             audio_array: Audio data as numpy array
             sampling_rate: Sampling rate of the audio
-
+
         Returns:
             Base64-encoded WAV file string
         """
         # Ensure float32 to create 32-bit WAV files (not 64-bit)
         # This prevents "Unsupported bit depth: 64" errors on the server
         audio_array = audio_array.astype(np.float32)
-
+
         # Create WAV file in memory
         wav_buffer = BytesIO()
         wavfile.write(wav_buffer, sampling_rate, audio_array)
         wav_bytes = wav_buffer.getvalue()
-
+
         # Encode to base64
-        base64_str = base64.b64encode(wav_bytes).decode('utf-8')
+        base64_str = base64.b64encode(wav_bytes).decode("utf-8")
         return base64_str
 
     def transcribe_audio(self, audio_array: np.ndarray, sampling_rate: int) -> str:
         """
         Transcribe audio using the tt-media-server HTTP API.
-
+
         Args:
             audio_array: Audio data as numpy array
             sampling_rate: Sampling rate of the audio
-
+
         Returns:
             Transcription text
         """
         # Encode audio to base64 WAV
         base64_audio = self.encode_audio_to_base64_wav(audio_array, sampling_rate)
-
+
         # Prepare request
         url = f"{self.base_url}/audio/transcriptions"
         headers = {
```
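The float32 cast above is the load-bearing line: `scipy.io.wavfile.write` emits 64-bit WAVs for float64 input, which the server rejects. A standalone sketch of the same encode path (function name is illustrative):

```python
import base64
from io import BytesIO

import numpy as np
from scipy.io import wavfile

def encode_wav_base64(audio: np.ndarray, rate: int) -> str:
    # Cast to float32 so wavfile.write produces a 32-bit WAV, then
    # serialize in memory and base64-encode for the JSON payload.
    buf = BytesIO()
    wavfile.write(buf, rate, audio.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode("utf-8")

# One second of silence at 16 kHz (float64 on purpose; the cast handles it)
print(len(encode_wav_base64(np.zeros(16000), 16000)))
```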
```diff
@@ -153,58 +153,50 @@ def transcribe_audio(self, audio_array: np.ndarray, sampling_rate: int) -> str:
         }
         if self.api_key:
             headers["Authorization"] = f"Bearer {self.api_key}"
-
-        payload = {
-            "file": base64_audio,
-            "stream": False
-        }
-
+
+        payload = {"file": base64_audio, "stream": False}
+
         # Make request with retries
         for attempt in range(self.max_retries):
             try:
-                response = requests.post(
-                    url,
-                    json=payload,
-                    headers=headers,
-                    timeout=self.timeout
-                )
+                response = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
                 response.raise_for_status()
-
+
                 # Parse response
                 result = response.json()
-
+
                 # Extract transcription text from response
                 # The response format should contain the transcription
                 if isinstance(result, dict):
                     # Try common keys for transcription text
-                    transcription = result.get('text') or result.get('transcription') or result.get('result')
+                    transcription = result.get("text") or result.get("transcription") or result.get("result")
                     if transcription:
                         return transcription
                     # If no known key, return the entire dict as string
                     eval_logger.warning(f"Unexpected response format: {result}")
                     return str(result)
                 else:
                     return str(result)
-
+
             except requests.exceptions.RequestException as e:
                 if attempt < self.max_retries - 1:
                     eval_logger.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}")
                     continue
                 else:
                     eval_logger.error(f"All retry attempts failed: {e}")
                     raise
-
+
         return ""
 
     async def _generate_audio_transcription(self, session, audio_array: np.ndarray, sampling_rate: int, audio_index: int = None) -> str:
         """
         Transcribe audio using the tt-media-server HTTP API.
-
+
         Args:
             audio_array: Audio data as numpy array
             sampling_rate: Sampling rate of the audio
             audio_index: Index of the audio for logging purposes
-
+
         Returns:
             Transcription text
         """
```
```diff
@@ -216,38 +208,27 @@ async def _generate_audio_transcription(self, session, audio_array: np.ndarray,
 
         # Prepare request
         url = f"{self.base_url}/audio/transcriptions"
-        headers = {
-            "accept": "application/json",
-            "Content-Type": "application/json"
-        }
+        headers = {"accept": "application/json", "Content-Type": "application/json"}
         if self.api_key:
             headers["Authorization"] = f"Bearer {self.api_key}"
-
-        payload = {
-            "file": base64_audio,
-            "stream": False
-        }
+
+        payload = {"file": base64_audio, "stream": False}
 
         try:
-            async with session.post(
-                f"{self.base_url}/audio/transcriptions",
-                json=payload,
-                headers=headers,
-                timeout=aiohttp.ClientTimeout(total=15000)
-            ) as response:
+            async with session.post(f"{self.base_url}/audio/transcriptions", json=payload, headers=headers, timeout=aiohttp.ClientTimeout(total=15000)) as response:
                 elapsed = time.time() - start_time
 
                 if response.status != 200:
                     eval_logger.info(f"❌ Audio transcription failed with status: {response.status}")
                     return ""
 
                 result = await response.json()
-
+
                 # Extract transcription text from response
                 # The response format should contain the transcription
                 if isinstance(result, dict):
                     # Try common keys for transcription text
-                    transcription = result.get('text') or result.get('transcription') or result.get('result')
+                    transcription = result.get("text") or result.get("transcription") or result.get("result")
                     eval_logger.info(f"Transcription result for audio {audio_index}: {transcription}")
                     if transcription:
                         return transcription
```
```diff
@@ -282,18 +263,18 @@ def _collate(x):
             return -len(toks), x[0]
 
         pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
-
+
         # Group requests by their generation_kwargs
         re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
-
+
         # Collect all audios from all chunks first
         all_audios = []
         all_contexts = []
         all_gen_kwargs_list = []
-
+
         time_start = time.time()
-
+
         for chunk in chunks:
             contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
             task = task[0]
```
```diff
@@ -318,12 +299,12 @@ def _collate(x):
             sampling_rate = self.processor.feature_extractor.sampling_rate
             assert sampling_rate == SAMPLING_RATE, f"Expected sampling rate {SAMPLING_RATE}, but got {sampling_rate}"
             audios = [downsample_audio(audio["array"], audio["sampling_rate"], sampling_rate) for audio in flattened_audios]
-
+
             # Collect all data
             all_audios.extend(audios)
             all_contexts.extend(contexts)
             all_gen_kwargs_list.extend([gen_kwargs] * len(contexts))
-
+
         time_end_prep = time.time()
         eval_logger.info(f"Preparation time for {len(all_audios)} requests: {time_end_prep - time_start:.2f}s")
 
```
```diff
@@ -334,7 +315,7 @@ async def run_transcriptions():
             return await asyncio.gather(*tasks)
 
         answers = asyncio.run(run_transcriptions())
-
+
         time_end_process = time.time()
 
         eval_logger.info(f"Total time for {len(all_audios)} requests across all chunks {time_end_process - time_start:.2f}s")
```
```diff
@@ -348,11 +329,11 @@ async def run_transcriptions():
                 until = gen_kwargs["until"]
                 if isinstance(until, str):
                     until = [until]
-
+
                 for term in until:
                     if len(term) > 0:
                         ans = ans.split(term)[0]
-
+
             processed_answers.append(ans)
 
         for ans, context, gen_kwargs in zip(processed_answers, all_contexts, all_gen_kwargs_list):
```
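The `until` handling normalizes a lone stop string into a list, then truncates each answer at the first occurrence of any stop sequence. The same logic as a standalone function (name is mine):

```python
def trim_at_stop_sequences(text: str, until) -> str:
    # Normalize a single string into a list, then cut at each stop sequence.
    if isinstance(until, str):
        until = [until]
    for term in until:
        if len(term) > 0:
            text = text.split(term)[0]
    return text

print(trim_at_stop_sequences("hello world<|eot|>trailing tokens", "<|eot|>"))
# -> "hello world"
```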
```diff
@@ -364,9 +345,9 @@ async def run_transcriptions():
         res = re_ords.get_original(res)
 
         pbar.close()
-
+
         time_end_process = time.time()
-
+
         eval_logger.info(f"Total time for {len(all_audios)} requests across all chunks {time_end_process - time_start:.2f}s")
 
         return res
```
