Skip to content

Commit f7328a4

Browse files
authored
Fix server (STT, TTS) (#460)
* Add python-multipart dependency to pyproject.toml for multipart form handling * Refactor tokenizer usage in Whisper model files - Removed direct imports and references to the Tokenizer class in favor of using model methods for tokenizer retrieval. - Updated function signatures to accept a generic tokenizer parameter instead of a specific Tokenizer type. - Enhanced the `detect_language` and `get_suppress_tokens` functions to utilize the model's tokenizer method, improving flexibility and reducing dependencies. - Deleted unused tokenizer asset files to streamline the codebase. * Add instruct and verbose fields to SpeechRequest model - Updated the SpeechRequest model to include an optional 'instruct' field for additional instructions. - Added a 'verbose' field to control output verbosity during audio generation. - Refactored the generate_audio function to utilize the new fields, enhancing flexibility in audio processing. * Add TranscriptionRequest model and update stt_transcriptions function - Introduced a new TranscriptionRequest model to encapsulate parameters for audio transcription, including fields for language, verbosity, and streaming options. - Updated the stt_transcriptions function to utilize the new model, enhancing parameter management and flexibility. - Implemented logic to handle streaming results and filter generation parameters based on the model's signature, improving the transcription process. * Add streaming transcription support and refactor STT model handling - Introduced a new `generate_transcription_stream` function to handle streaming transcription, yielding results in real-time and managing temporary file cleanup. - Updated the `stt_transcriptions` endpoint to utilize the new streaming function, enhancing the responsiveness of audio transcription. - Refactored the `Qwen3ASRModel` to support streaming transcriptions, including a new `StreamingResult` data class for structured output. 
- Improved model remapping logic in `get_model_category` to prioritize explicit remapping matches, enhancing model loading flexibility. - Suppressed warnings during tokenizer loading in both `Qwen3ASRModel` and `ForcedAlignerModel` to improve user experience during model initialization. * format * fix tests * Enhance chunk processing in Qwen3ASRModel with progress indication - Added a tqdm progress bar to the chunk processing loop in the Qwen3ASRModel, improving user feedback during audio processing. - The progress bar is configurable based on verbosity and the number of chunks, enhancing the overall user experience. * Refactor audio chunk parameters in Qwen3ASRModel and related functions - Renamed `max_chunk_sec` and `min_chunk_sec` to `chunk_duration` and `min_chunk_duration` respectively for clarity and consistency across the codebase. - Updated function signatures and internal logic to reflect the new parameter names, ensuring proper handling of audio chunk durations. - Adjusted test cases to align with the updated parameter names, maintaining test integrity. * format
1 parent be5676f commit f7328a4

File tree

16 files changed

+419
-100917
lines changed

16 files changed

+419
-100917
lines changed

mlx_audio/server.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import argparse
99
import asyncio
10+
import inspect
1011
import io
1112
import json
1213
import os
@@ -153,6 +154,7 @@ def setup_cors(app: FastAPI, allowed_origins: List[str]):
153154
class SpeechRequest(BaseModel):
154155
model: str
155156
input: str
157+
instruct: str | None = None
156158
voice: str | None = None
157159
speed: float | None = 1.0
158160
gender: str | None = "male"
@@ -165,6 +167,20 @@ class SpeechRequest(BaseModel):
165167
top_k: int | None = 40
166168
repetition_penalty: float | None = 1.0
167169
response_format: str | None = "mp3"
170+
verbose: bool = False
171+
172+
173+
class TranscriptionRequest(BaseModel):
    """Parameters for the /v1/audio/transcriptions endpoint.

    Mirrors the Form fields accepted by ``stt_transcriptions``; the endpoint
    builds one of these from the multipart form and forwards the non-None
    fields (minus ``model``) to the STT model's ``generate`` method.
    """

    model: str  # model identifier passed to model_provider.load_model
    language: str | None = None  # optional language hint; None lets the model decide
    verbose: bool = False  # forwarded to generate for progress/log output
    max_tokens: int = 128  # generation cap per call
    chunk_duration: float = 30.0  # audio chunk length — presumably seconds; renamed from max_chunk_sec
    frame_threshold: int = 25  # model-specific frame cutoff — TODO confirm semantics against the STT model
    stream: bool = False  # request streaming transcription from the model
    context: str | None = None  # optional prior text/context for the model
    prefill_step_size: int = 2048  # prompt prefill batch size — verify against model docs
    text: str | None = None  # optional text input (e.g. for forced alignment) — confirm with caller
168184

169185

170186
# Initialize the ModelProvider
@@ -234,7 +250,7 @@ async def remove_model(model_name: str):
234250
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
235251

236252

237-
async def generate_audio(model, payload: SpeechRequest, verbose: bool = False):
253+
async def generate_audio(model, payload: SpeechRequest):
238254
# Load reference audio if provided
239255
ref_audio = payload.ref_audio
240256
if ref_audio and isinstance(ref_audio, str):
@@ -258,13 +274,15 @@ async def generate_audio(model, payload: SpeechRequest, verbose: bool = False):
258274
speed=payload.speed,
259275
gender=payload.gender,
260276
pitch=payload.pitch,
277+
instruct=payload.instruct,
261278
lang_code=payload.lang_code,
262279
ref_audio=ref_audio,
263280
ref_text=payload.ref_text,
264281
temperature=payload.temperature,
265282
top_p=payload.top_p,
266283
top_k=payload.top_k,
267284
repetition_penalty=payload.repetition_penalty,
285+
verbose=payload.verbose,
268286
):
269287

270288
sample_rate = result.sample_rate
@@ -287,25 +305,87 @@ async def tts_speech(payload: SpeechRequest):
287305
)
288306

289307

308+
def generate_transcription_stream(stt_model, tmp_path: str, gen_kwargs: dict):
    """Yield NDJSON lines of transcription output, then remove the temp file.

    Calls ``stt_model.generate(tmp_path, **gen_kwargs)``. If the model returns
    an iterator (streaming mode), each chunk is serialized as one JSON line;
    otherwise the whole result is emitted as a single line. The temporary
    audio file is deleted in all cases, including on error.
    """
    try:
        output = stt_model.generate(tmp_path, **gen_kwargs)

        # Streaming mode is detected by the iterator protocol (both dunders,
        # so plain strings/containers do not match).
        if not (hasattr(output, "__iter__") and hasattr(output, "__next__")):
            # Non-streaming model: emit the full result as one NDJSON line.
            yield json.dumps(sanitize_for_json(output)) + "\n"
            return

        running_text = ""
        for piece in output:
            if isinstance(piece, str):
                # Plain token stream: also report the concatenation so far.
                running_text += piece
                payload = {"text": piece, "accumulated": running_text}
            else:
                # Structured chunk (e.g. Whisper streaming); optional fields
                # default to None when the chunk type lacks them.
                payload = {
                    "text": piece.text,
                    "start": getattr(piece, "start_time", None),
                    "end": getattr(piece, "end_time", None),
                    "is_final": getattr(piece, "is_final", None),
                    "language": getattr(piece, "language", None),
                }
            yield json.dumps(sanitize_for_json(payload)) + "\n"
    finally:
        # Always discard the temporary audio file, even if generation failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
338+
339+
290340
@app.post("/v1/audio/transcriptions")
async def stt_transcriptions(
    file: UploadFile = File(...),
    model: str = Form(...),
    language: Optional[str] = Form(None),
    verbose: bool = Form(False),
    max_tokens: int = Form(128),
    chunk_duration: float = Form(30.0),
    frame_threshold: int = Form(25),
    stream: bool = Form(False),
    context: Optional[str] = Form(None),
    prefill_step_size: int = Form(2048),
    text: Optional[str] = Form(None),
):
    """Transcribe uploaded audio with an STT model (OpenAI-style API).

    The multipart form fields are collected into a ``TranscriptionRequest``,
    the audio is decoded and re-written to a temporary file, and the result
    is returned as an NDJSON ``StreamingResponse`` produced by
    ``generate_transcription_stream`` (which also deletes the temp file).

    Raises whatever ``model_provider.load_model`` / the audio codecs raise;
    the temp file is cleaned up if setup fails before streaming starts.
    """
    import tempfile  # stdlib; local import keeps the top-of-file imports untouched

    # Collect form fields into the request model so downstream code has one
    # validated object to draw generate() kwargs from.
    payload = TranscriptionRequest(
        model=model,
        language=language,
        verbose=verbose,
        max_tokens=max_tokens,
        chunk_duration=chunk_duration,
        frame_threshold=frame_threshold,
        stream=stream,
        context=context,
        prefill_step_size=prefill_step_size,
        text=text,
    )

    data = await file.read()
    buffer = io.BytesIO(data)
    audio, sr = audio_read(buffer, always_2d=False)
    buffer.close()

    # Use a unique temp file rather than f"/tmp/{time.time()}.mp3": time-based
    # names collide under concurrent requests and a hard-coded /tmp is not
    # portable. mkstemp guarantees a fresh path; we close the fd and let
    # audio_write reopen it by name.
    fd, tmp_path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)
    try:
        audio_write(tmp_path, audio, sr)

        stt_model = model_provider.load_model(payload.model)

        # Build kwargs for generate, dropping None values, then keep only the
        # parameters this model's generate() actually accepts.
        gen_kwargs = payload.model_dump(exclude={"model"}, exclude_none=True)
        signature = inspect.signature(stt_model.generate)
        gen_kwargs = {k: v for k, v in gen_kwargs.items() if k in signature.parameters}
    except Exception:
        # The stream generator normally owns cleanup; if we fail before
        # handing the file off, remove it here instead of leaking it.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return StreamingResponse(
        generate_transcription_stream(stt_model, tmp_path, gen_kwargs),
        media_type="application/x-ndjson",
    )
309389

310390

311391
@app.websocket("/v1/audio/transcriptions/realtime")

mlx_audio/stt/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def parse_args():
8383
"--gen-kwargs",
8484
type=json.loads,
8585
default=None,
86-
help='Additional generate kwargs as JSON (e.g. \'{"max_chunk_sec": 600, "min_chunk_sec": 1.0}\')',
86+
help="Additional generate kwargs as JSON (e.g. '{\"min_chunk_duration\": 1.0}')",
8787
)
8888
parser.add_argument(
8989
"--text",

0 commit comments

Comments (0)