Skip to content

Commit 488b9d7

Browse files
authored
STT Evaluation: Gemini Files API instead of signed URLs (#685)
1 parent 06e4cb2 commit 488b9d7

File tree

4 files changed

+414
-143
lines changed

4 files changed

+414
-143
lines changed

backend/app/core/batch/gemini.py

Lines changed: 144 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import logging
5+
import mimetypes
56
import os
67
import tempfile
78
import time
@@ -11,8 +12,6 @@
1112
from google import genai
1213
from google.genai import types
1314

14-
from app.core.storage_utils import get_mime_from_url
15-
1615
from .base import BATCH_KEY, BatchProvider
1716

1817
logger = logging.getLogger(__name__)
@@ -292,6 +291,46 @@ def _extract_text_from_response_dict(response: dict[str, Any]) -> str:
292291
"""Extract text content from a Gemini response dictionary."""
293292
return extract_text_from_response_dict(response)
294293

294+
def _upload_to_gemini(
    self,
    content: str | bytes,
    suffix: str,
    mime_type: str,
    display_name: str,
) -> types.File:
    """Write content to a temp file, upload it to Gemini, and clean up.

    The temp file is created with ``delete=False`` and is closed (leaving the
    ``with`` block) before the upload is issued, because the upload call is
    given the file *path* rather than the open handle. The file is removed in
    ``finally`` regardless of whether the write or the upload fails.

    Args:
        content: File content (text or binary)
        suffix: Temp file suffix (e.g., ".jsonl", ".mp3")
        mime_type: MIME type for the upload
        display_name: Display name in Gemini

    Returns:
        Gemini File object

    Raises:
        Exception: Any error raised while writing the temp file or by
            ``self._client.files.upload`` is propagated unchanged.
    """
    # "w" for text (JSONL batch) or "wb" for binary (audio files for STT),
    # since this method accepts content: str | bytes.
    mode = "w" if isinstance(content, str) else "wb"
    kwargs: dict[str, Any] = {"suffix": suffix, "delete": False, "mode": mode}
    if mode == "w":
        kwargs["encoding"] = "utf-8"

    tmp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(**kwargs) as tmp_file:
            # Record the path BEFORE writing: with delete=False nothing else
            # removes the file, so a failed write must still reach the
            # cleanup in `finally`.
            tmp_path = tmp_file.name
            tmp_file.write(content)

        return self._client.files.upload(
            file=tmp_path,
            config=types.UploadFileConfig(
                display_name=display_name,
                mime_type=mime_type,
            ),
        )
    finally:
        # tmp_path stays None only if the temp file could not be created.
        if tmp_path is not None:
            os.unlink(tmp_path)
333+
295334
def upload_file(self, content: str, purpose: str = "batch") -> str:
296335
"""Upload a JSONL file to Gemini Files API.
297336
@@ -305,35 +344,98 @@ def upload_file(self, content: str, purpose: str = "batch") -> str:
305344
logger.info(f"[upload_file] Uploading file to Gemini | bytes={len(content)}")
306345

307346
try:
308-
with tempfile.NamedTemporaryFile(
309-
suffix=".jsonl", delete=False, mode="w", encoding="utf-8"
310-
) as tmp_file:
311-
tmp_file.write(content)
312-
tmp_path = tmp_file.name
347+
uploaded_file = self._upload_to_gemini(
348+
content=content,
349+
suffix=".jsonl",
350+
mime_type="jsonl",
351+
display_name=f"batch-input-{int(time.time())}",
352+
)
313353

314-
try:
315-
uploaded_file = self._client.files.upload(
316-
file=tmp_path,
317-
config=types.UploadFileConfig(
318-
display_name=f"batch-input-{int(time.time())}",
319-
mime_type="jsonl",
320-
),
321-
)
354+
logger.info(
355+
f"[upload_file] Uploaded file to Gemini | "
356+
f"file_name={uploaded_file.name}"
357+
)
322358

323-
logger.info(
324-
f"[upload_file] Uploaded file to Gemini | "
325-
f"file_name={uploaded_file.name}"
326-
)
359+
return uploaded_file.name
360+
361+
except Exception as e:
362+
logger.error(f"[upload_file] Failed to upload file to Gemini | {e}")
363+
raise
327364

328-
return uploaded_file.name
365+
def upload_audio_file(
366+
self,
367+
content: bytes,
368+
mime_type: str,
369+
display_name: str | None = None,
370+
) -> tuple[str, str]:
371+
"""Upload an audio file to Gemini File API.
329372
330-
finally:
331-
os.unlink(tmp_path)
373+
Args:
374+
content: Raw audio file bytes
375+
mime_type: MIME type of the audio (e.g., 'audio/mpeg')
376+
display_name: Optional display name for the file
377+
378+
Returns:
379+
Tuple of (file_name, file_uri):
380+
- file_name: Short name for API calls (e.g., "files/xxx")
381+
- file_uri: Full URI for use in batch requests
382+
(e.g., "https://generativelanguage.googleapis.com/v1beta/files/xxx")
383+
"""
384+
display_name = display_name or f"stt-audio-{int(time.time())}"
385+
logger.info(
386+
f"[upload_audio_file] Uploading audio to Gemini | "
387+
f"bytes={len(content)} | mime_type={mime_type} | display_name={display_name}"
388+
)
389+
390+
try:
391+
uploaded_file = self._upload_to_gemini(
392+
content=content,
393+
suffix=mimetypes.guess_extension(mime_type) or ".bin",
394+
mime_type=mime_type,
395+
display_name=display_name,
396+
)
397+
398+
logger.info(
399+
f"[upload_audio_file] Uploaded audio to Gemini | "
400+
f"file_name={uploaded_file.name} | file_uri={uploaded_file.uri}"
401+
)
402+
403+
return uploaded_file.name, uploaded_file.uri
332404

333405
except Exception as e:
334-
logger.error(f"[upload_file] Failed to upload file to Gemini | {e}")
406+
logger.error(f"[upload_audio_file] Failed to upload audio to Gemini | {e}")
335407
raise
336408

409+
def delete_files(self, file_names: list[str]) -> tuple[int, int]:
    """Best-effort removal of a batch of files from the Gemini File API.

    Each name is deleted independently: a failure on one file is logged as a
    warning and counted, and the loop moves on to the next name instead of
    aborting.

    Args:
        file_names: Gemini file names to delete (e.g., ["files/xxx", ...])

    Returns:
        Tuple of (success_count, failure_count)
    """
    deleted = 0
    failed = 0

    for file_name in file_names:
        try:
            self._client.files.delete(name=file_name)
        except Exception as exc:
            # Deliberately broad: cleanup must not raise to the caller.
            failed += 1
            logger.warning(
                f"[delete_files] Failed to delete Gemini file | "
                f"file_name={file_name} | error={exc}"
            )
        else:
            deleted += 1

    logger.info(
        f"[delete_files] Gemini file cleanup complete | "
        f"deleted={deleted}, failed={failed}"
    )

    return deleted, failed
438+
337439
def download_file(self, file_id: str) -> str:
338440
"""Download a file from Gemini Files API.
339441
@@ -387,18 +489,20 @@ def _extract_text_from_response(response: Any) -> str:
387489

388490

389491
def create_stt_batch_requests(
390-
signed_urls: list[str],
492+
file_uris: list[str],
493+
mime_types: list[str],
391494
prompt: str,
392495
keys: list[str] | None = None,
393496
) -> list[dict[str, Any]]:
394497
"""
395-
Create batch API requests for Gemini STT using signed URLs.
498+
Create batch API requests for Gemini STT using Gemini File API URIs.
396499
397500
This function generates request payloads in Gemini's JSONL batch format
398-
using signed URLs directly. MIME types are automatically detected from the URL path.
501+
using file URIs from the Gemini File API.
399502
400503
Args:
401-
signed_urls: List of signed URLs pointing to audio files
504+
file_uris: List of Gemini File API URIs (e.g., "https://generativelanguage.googleapis.com/v1beta/files/abc123")
505+
mime_types: List of MIME types corresponding to each file URI
402506
prompt: Transcription prompt/instructions for the model
403507
keys: Optional list of custom IDs for tracking results. If not provided,
404508
uses 0-indexed integers as strings.
@@ -408,25 +512,25 @@ def create_stt_batch_requests(
408512
{"key": "sample-1", "request": {"contents": [...]}}
409513
410514
Example:
411-
>>> urls = ["https://bucket.s3.amazonaws.com/audio.mp3?..."]
515+
>>> uris = ["https://generativelanguage.googleapis.com/v1beta/files/abc123"]
516+
>>> mime_types = ["audio/mpeg"]
412517
>>> prompt = "Transcribe this audio file."
413-
>>> requests = create_stt_batch_requests(urls, prompt, keys=["sample-1"])
518+
>>> requests = create_stt_batch_requests(uris, mime_types, prompt, keys=["sample-1"])
414519
>>> provider.create_batch(requests, {"display_name": "stt-batch"})
415520
"""
416-
if keys is not None and len(keys) != len(signed_urls):
521+
if len(file_uris) != len(mime_types):
522+
raise ValueError(
523+
f"Length of file_uris ({len(file_uris)}) must match mime_types ({len(mime_types)})"
524+
)
525+
526+
if keys is not None and len(keys) != len(file_uris):
417527
raise ValueError(
418-
f"Length of keys ({len(keys)}) must match signed_urls ({len(signed_urls)})"
528+
f"Length of keys ({len(keys)}) must match file_uris ({len(file_uris)})"
419529
)
420530

421531
requests = []
422-
for i, url in enumerate(signed_urls):
423-
mime_type = get_mime_from_url(url)
424-
if mime_type is None:
425-
logger.warning(
426-
f"[create_stt_batch_requests] Could not determine MIME type for URL | "
427-
f"index={i} | defaulting to audio/mpeg"
428-
)
429-
mime_type = "audio/mpeg"
532+
for i, uri in enumerate(file_uris):
533+
mime_type = mime_types[i]
430534

431535
# Use provided key or generate from index
432536
key = keys[i] if keys is not None else str(i)
@@ -439,7 +543,7 @@ def create_stt_batch_requests(
439543
{
440544
"parts": [
441545
{"text": prompt},
442-
{"file_data": {"mime_type": mime_type, "file_uri": url}},
546+
{"file_data": {"mime_type": mime_type, "file_uri": uri}},
443547
],
444548
"role": "user",
445549
}

0 commit comments

Comments
 (0)