
Commit de2c9ef

NickLucche authored and DarkLight1337 committed
[Frontend] Gemma3n audio transcriptions/translations endpoint (vllm-project#23735)
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 4ef05c6 commit de2c9ef
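
As a quick orientation before the diff, a minimal sketch of calling the new Gemma3n transcription endpoint through the official OpenAI client, mirroring the test added below. The base URL, API key, and audio file name are assumptions for a locally served model, not part of this commit:

    # Sketch only: assumes `vllm serve google/gemma-3n-E2B-it --enforce-eager`
    # is running locally; "sample_it.wav" is a hypothetical Italian audio clip.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    with open("sample_it.wav", "rb") as f:
        transcription = client.audio.transcriptions.create(
            model="google/gemma-3n-E2B-it",
            file=f,
            language="it",
            response_format="text",
            temperature=0.0,
        )
    print(transcription)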

File tree: 9 files changed, +189 −63 lines changed

tests/entrypoints/openai/conftest.py (new file)

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.assets.audio import AudioAsset
+
+
+@pytest.fixture
+def mary_had_lamb():
+    path = AudioAsset('mary_had_lamb').get_local_path()
+    with open(str(path), "rb") as f:
+        yield f
+
+
+@pytest.fixture
+def winning_call():
+    path = AudioAsset('winning_call').get_local_path()
+    with open(str(path), "rb") as f:
+        yield f
+
+
+@pytest.fixture
+def foscolo():
+    # Test translation it->en
+    path = AudioAsset('azacinto_foscolo').get_local_path()
+    with open(str(path), "rb") as f:
+        yield f
tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 19 additions & 16 deletions

@@ -12,8 +12,6 @@
 import pytest_asyncio
 import soundfile as sf
 
-from vllm.assets.audio import AudioAsset
-
 from ...utils import RemoteOpenAIServer
 
 MODEL_NAME = "openai/whisper-large-v3-turbo"
@@ -24,20 +22,6 @@
 ]
 
 
-@pytest.fixture
-def mary_had_lamb():
-    path = AudioAsset('mary_had_lamb').get_local_path()
-    with open(str(path), "rb") as f:
-        yield f
-
-
-@pytest.fixture
-def winning_call():
-    path = AudioAsset('winning_call').get_local_path()
-    with open(str(path), "rb") as f:
-        yield f
-
-
 @pytest.fixture(scope="module")
 def server():
     with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
@@ -76,6 +60,25 @@ async def test_basic_audio(mary_had_lamb, model_name):
     assert out_usage["seconds"] == 16, out_usage["seconds"]
 
 
+@pytest.mark.asyncio
+async def test_basic_audio_gemma(foscolo):
+    # Gemma accuracy on some of the audio samples we use is particularly bad,
+    # hence we use a different one here. WER is evaluated separately.
+    model_name = "google/gemma-3n-E2B-it"
+    server_args = ["--enforce-eager"]
+
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=foscolo,
+            language="it",
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(transcription)['text']
+        assert "da cui vergine nacque Venere" in out
+
+
 @pytest.mark.asyncio
 async def test_non_asr_model(winning_call):
     # text to text model
tests/entrypoints/openai/test_translation_validation.py

Lines changed: 44 additions & 34 deletions

@@ -12,32 +12,24 @@
 import pytest_asyncio
 import soundfile as sf
 
-from vllm.assets.audio import AudioAsset
-
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "openai/whisper-small"
 SERVER_ARGS = ["--enforce-eager"]
 
 
-@pytest.fixture
-def foscolo():
-    # Test translation it->en
-    path = AudioAsset('azacinto_foscolo').get_local_path()
-    with open(str(path), "rb") as f:
-        yield f
-
-
-@pytest.fixture(scope="module")
-def server():
-    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
-        yield remote_server
+@pytest.fixture(scope="module",
+                params=["openai/whisper-small", "google/gemma-3n-E2B-it"])
+def server(request):
+    # Parametrize over model name
+    with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server:
+        yield remote_server, request.param
 
 
 @pytest_asyncio.fixture
-async def client(server):
+async def client_and_model(server):
+    server, model_name = server
     async with server.get_async_client() as async_client:
-        yield async_client
+        yield async_client, model_name
 
 
 @pytest.mark.asyncio
@@ -56,27 +48,29 @@ async def test_non_asr_model(foscolo):
 
 # NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
 @pytest.mark.asyncio
-async def test_basic_audio(foscolo, client):
+async def test_basic_audio(foscolo, client_and_model):
+    client, model_name = client_and_model
     translation = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
         file=foscolo,
         response_format="text",
-        # TODO remove once language detection is implemented
-        extra_body=dict(language="it"),
+        # TODO remove `language="it"` once language detection is implemented
+        extra_body=dict(language="it", to_language="en"),
         temperature=0.0)
     out = json.loads(translation)['text'].strip().lower()
     assert "greek sea" in out
 
 
 @pytest.mark.asyncio
-async def test_audio_prompt(foscolo, client):
+async def test_audio_prompt(foscolo, client_and_model):
+    client, model_name = client_and_model
     # Condition whisper on starting text
     prompt = "Nor have I ever"
     transcription = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
        file=foscolo,
         prompt=prompt,
-        extra_body=dict(language="it"),
+        extra_body=dict(language="it", to_language="en"),
         response_format="text",
         temperature=0.0)
     out = json.loads(transcription)['text']
@@ -85,22 +79,27 @@ async def test_audio_prompt(foscolo, client):
 
 
 @pytest.mark.asyncio
-async def test_streaming_response(foscolo, client, server):
+async def test_streaming_response(foscolo, client_and_model, server):
+    client, model_name = client_and_model
     translation = ""
     res_no_stream = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
         file=foscolo,
         response_format="json",
-        extra_body=dict(language="it"),
+        extra_body=dict(language="it", to_language="en", seed=42),
         temperature=0.0)
+
     # Stream via HTTPX since OpenAI translation client doesn't expose streaming
+    server, model_name = server
     url = server.url_for("v1/audio/translations")
     headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
     data = {
-        "model": MODEL_NAME,
+        "model": model_name,
         "language": "it",
+        "to_language": "en",
         "stream": True,
         "temperature": 0.0,
+        "seed": 42,
     }
     foscolo.seek(0)
     async with httpx.AsyncClient() as http_client:
@@ -121,16 +120,24 @@ async def test_streaming_response(foscolo, client, server):
             text = chunk["choices"][0].get("delta", {}).get("content")
             translation += text or ""
 
-    assert translation == res_no_stream.text
+    res_stream = translation.split()
+    # NOTE There's a small non-deterministic issue here, likely in the attn
+    # computation, which will cause a few tokens to be different, while still
+    # being very close semantically.
+    assert sum([
+        x == y for x, y in zip(res_stream, res_no_stream.text.split())
+    ]) >= len(res_stream) * 0.9
 
 
 @pytest.mark.asyncio
-async def test_stream_options(foscolo, client, server):
+async def test_stream_options(foscolo, server):
+    server, model_name = server
     url = server.url_for("v1/audio/translations")
     headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
     data = {
-        "model": MODEL_NAME,
+        "model": model_name,
         "language": "it",
+        "to_language": "en",
         "stream": True,
         "stream_include_usage": True,
         "stream_continuous_usage_stats": True,
@@ -164,7 +171,10 @@ async def test_stream_options(foscolo, client, server):
 
 
 @pytest.mark.asyncio
-async def test_long_audio_request(foscolo, client):
+async def test_long_audio_request(foscolo, client_and_model):
+    client, model_name = client_and_model
+    if model_name == "google/gemma-3n-E2B-it":
+        pytest.skip("Gemma3n does not support long audio requests")
     foscolo.seek(0)
     audio, sr = librosa.load(foscolo)
     repeated_audio = np.tile(audio, 2)
@@ -173,9 +183,9 @@ async def test_long_audio_request(foscolo, client):
     sf.write(buffer, repeated_audio, sr, format='WAV')
     buffer.seek(0)
     translation = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
         file=buffer,
-        extra_body=dict(language="it"),
+        extra_body=dict(language="it", to_language="en"),
         response_format="text",
         temperature=0.0)
     out = json.loads(translation)['text'].strip().lower()
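
For readers following the streaming path above, a condensed, self-contained sketch of the same raw-HTTP request; the server address and file name are placeholders, while the form fields match the `data` dict used in `test_streaming_response`. Streaming goes through HTTPX because the OpenAI client does not expose a streaming variant for translations:

    # Sketch: stream a translation over raw HTTP against a local vLLM server.
    # "sample_it.wav" and the base URL are assumptions, not part of this diff.
    import asyncio
    import httpx

    async def stream_translation() -> None:
        url = "http://localhost:8000/v1/audio/translations"
        data = {
            "model": "openai/whisper-small",
            "language": "it",
            "to_language": "en",  # field introduced by this commit
            "stream": True,
            "temperature": 0.0,
            "seed": 42,  # field introduced by this commit
        }
        with open("sample_it.wav", "rb") as f:
            async with httpx.AsyncClient() as http_client:
                async with http_client.stream("POST", url, files={"file": f},
                                              data=data) as res:
                    async for line in res.aiter_lines():
                        # SSE frames look like `data: {...}`; the last is [DONE]
                        if line.startswith("data: ") and "[DONE]" not in line:
                            print(line.removeprefix("data: "))

    asyncio.run(stream_translation())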

vllm/entrypoints/openai/protocol.py

Lines changed: 19 additions & 0 deletions

@@ -2175,6 +2175,13 @@ class TranscriptionRequest(OpenAIBaseModel):
     )
     # --8<-- [end:transcription-extra-params]
 
+    to_language: Optional[str] = None
+    """The target language to transcribe into.
+
+    Note that this is currently unused by supported models; it is a
+    placeholder for future use, matching the translation API.
+    """
+
     # --8<-- [start:transcription-sampling-params]
     temperature: float = Field(default=0.0)
     """The sampling temperature, between 0 and 1.
@@ -2408,6 +2415,9 @@ class TranslationRequest(OpenAIBaseModel):
 
     # TODO support additional sampling parameters
     # --8<-- [start:translation-sampling-params]
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
+    """The seed to use for sampling."""
+
     temperature: float = Field(default=0.0)
     """The sampling temperature, between 0 and 1.
@@ -2427,6 +2437,14 @@
     will improve accuracy.
     """
 
+    to_language: Optional[str] = None
+    """The target language to translate the input audio into.
+
+    Note that this is not supported by all models; refer to the specific
+    model's documentation for details.
+    For instance, Whisper only supports `to_language=en`.
+    """
+
     stream: Optional[bool] = False
     """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
@@ -2458,6 +2476,7 @@ def to_sampling_params(
 
         return SamplingParams.from_optional(temperature=temperature,
                                             max_tokens=max_tokens,
+                                            seed=self.seed,
                                             output_kind=RequestOutputKind.DELTA
                                             if self.stream \
                                             else RequestOutputKind.FINAL_ONLY)
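
A short sketch of how clients reach the new fields: `seed` and `to_language` are vLLM extensions to the OpenAI translation API, so the SDK passes them through `extra_body`, as the tests above do. The endpoint, key, and file name below are placeholders:

    # Sketch: pass the vLLM-specific TranslationRequest fields via extra_body.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    with open("sample_it.wav", "rb") as f:  # hypothetical Italian audio clip
        translation = client.audio.translations.create(
            model="openai/whisper-small",
            file=f,
            response_format="text",
            extra_body=dict(language="it", to_language="en", seed=42),
            temperature=0.0,
        )
    print(translation)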

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 6 additions & 1 deletion

@@ -89,6 +89,9 @@ async def _preprocess_speech_to_text(
 ) -> tuple[list[PromptType], float]:
     # Validate request
     language = self.model_cls.validate_language(request.language)
+    # Skip to_language validation to avoid extra logging for Whisper.
+    to_language = self.model_cls.validate_language(request.to_language) \
+        if request.to_language else None
 
     if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
         raise ValueError("Maximum file size exceeded.")
@@ -112,7 +115,9 @@
         model_config=self.model_config,
         language=language,
         task_type=self.task_type,
-        request_prompt=request.prompt)
+        request_prompt=request.prompt,
+        to_language=to_language,
+    )
     prompts.append(prompt)
     return prompts, duration
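
To read the validation flow above in isolation, a standalone sketch; `validate_language` stands in for `self.model_cls.validate_language`, and the size limit is an assumed placeholder rather than vLLM's configured value:

    # Sketch of the guard added above: to_language is only validated when set,
    # which avoids extra log noise for models (e.g. Whisper) that can only
    # translate into English. All names here are illustrative stand-ins.
    from typing import Callable, Optional

    def validate_speech_request(
        language: Optional[str],
        to_language: Optional[str],
        audio_data: bytes,
        validate_language: Callable[[Optional[str]], Optional[str]],
        max_audio_filesize_mb: float = 25,  # assumed limit for illustration
    ) -> tuple[Optional[str], Optional[str]]:
        lang = validate_language(language)
        # Skip to_language validation when unset (mirrors the diff above).
        to_lang = validate_language(to_language) if to_language else None
        if len(audio_data) / 1024**2 > max_audio_filesize_mb:
            raise ValueError("Maximum file size exceeded.")
        return lang, to_lang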
