Commit 4090ff2

chore(format): run black on dev (#915)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent d582fd5 · commit 4090ff2
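
This is a formatting-only change: Black was run over the example API script and notebook shown below, so the diff normalizes whitespace, quotes, and line wrapping without changing behavior. As a rough sketch of reproducing such a pass locally (the exact command and targets used by the repo's automation are not shown in this commit, so treat the path and flags below as assumptions):

import subprocess

# Black needs the "jupyter" extra installed to format .ipynb files as well as .py files.
# --check --diff only reports what would change; drop both flags to rewrite files in place.
subprocess.run(
    ["black", "--check", "--diff", "examples/api/"],
    check=False,  # Black exits non-zero when files would be reformatted
)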

File tree

3 files changed: +109 −63 lines changed


examples/api/openai_api.py

Lines changed: 84 additions & 43 deletions
@@ -12,11 +12,12 @@
 - Use asyncio.Lock to manage model access, improving concurrency performance
 - Load and manage speaker embedding files to support personalized speech synthesis
 """
+
 import io
 import os
 import sys
 import asyncio
-import time
+import time
 from typing import Optional, Dict
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse, JSONResponse
@@ -57,48 +58,72 @@
 # Allowed audio formats
 ALLOWED_FORMATS = {"mp3", "wav", "ogg"}
 
+
 @app.on_event("startup")
 async def startup_event():
 """Load ChatTTS model and default speaker embedding when the application starts"""
 # Initialize ChatTTS and async lock
 app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
 app.state.model_lock = asyncio.Lock()  # Use async lock instead of thread lock
-
+
 # Register text normalizers
 app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
 app.state.chat.normalizer.register("zh", normalizer_zh_tn())
-
+
 logger.info("Initializing ChatTTS...")
 if app.state.chat.load(source="huggingface"):
 logger.info("Model loaded successfully.")
 else:
 logger.error("Model loading failed, exiting application.")
 raise RuntimeError("Failed to load ChatTTS model")
-
+
 # Load default speaker embedding
 # Preload all supported speaker embeddings into memory at startup to avoid repeated loading during runtime
 app.state.spk_emb_map = {}
 for voice, spk_path in VOICE_MAP.items():
 if os.path.exists(spk_path):
-app.state.spk_emb_map[voice] = torch.load(spk_path, map_location=torch.device("cpu"))
+app.state.spk_emb_map[voice] = torch.load(
+spk_path, map_location=torch.device("cpu")
+)
 logger.info(f"Preloading speaker embedding: {voice} -> {spk_path}")
 else:
 logger.warning(f"Speaker embedding not found: {spk_path}, skipping preload")
 app.state.spk_emb = app.state.spk_emb_map.get("default")  # Default embedding
 
+
 # Request parameter whitelist
-ALLOWED_PARAMS = {"model", "input", "voice", "response_format", "speed", "stream", "output_format"}
+ALLOWED_PARAMS = {
+"model",
+"input",
+"voice",
+"response_format",
+"speed",
+"stream",
+"output_format",
+}
+
 
 class OpenAITTSRequest(BaseModel):
 """OpenAI TTS request data model"""
+
 model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
-input: str = Field(..., description="Text content to synthesize", max_length=2048) # Length limit
-voice: Optional[str] = Field("default", description="Voice selection, supports: default, alloy, echo")
-response_format: Optional[str] = Field("mp3", description="Audio format: mp3, wav, ogg")
-speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0")
+input: str = Field(
+..., description="Text content to synthesize", max_length=2048
+)  # Length limit
+voice: Optional[str] = Field(
+"default", description="Voice selection, supports: default, alloy, echo"
+)
+response_format: Optional[str] = Field(
+"mp3", description="Audio format: mp3, wav, ogg"
+)
+speed: Optional[float] = Field(
+1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0"
+)
 stream: Optional[bool] = Field(False, description="Whether to stream")
 output_format: Optional[str] = "mp3"  # Optional formats: mp3, wav, ogg
-extra_params: Dict[str, Optional[str]] = Field(default_factory=dict, description="Unsupported extra parameters")
+extra_params: Dict[str, Optional[str]] = Field(
+default_factory=dict, description="Unsupported extra parameters"
+)
 
 @classmethod
 def validate_request(cls, request_data: Dict):
@@ -109,31 +134,38 @@ def validate_request(cls, request_data: Dict):
 logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
 return {key: request_data[key] for key in ALLOWED_PARAMS if key in request_data}
 
+
 # Unified error response
 @app.exception_handler(Exception)
 async def custom_exception_handler(request, exc):
 """Custom exception handler"""
 logger.error(f"Error: {str(exc)}")
 return JSONResponse(
 status_code=getattr(exc, "status_code", 500),
-content={"error": {"message": str(exc), "type": exc.__class__.__name__}}
+content={"error": {"message": str(exc), "type": exc.__class__.__name__}},
 )
 
+
 @app.post("/v1/audio/speech")
 async def generate_voice(request_data: Dict):
 """Handle speech synthesis request"""
 request_data = OpenAITTSRequest.validate_request(request_data)
 request = OpenAITTSRequest(**request_data)
-
-logger.info(f"Received request: text={request.input}..., voice={request.voice}, stream={request.stream}")
-
+
+logger.info(
+f"Received request: text={request.input}..., voice={request.voice}, stream={request.stream}"
+)
+
 # Validate audio format
 if request.response_format not in ALLOWED_FORMATS:
-raise HTTPException(400, detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}")
+raise HTTPException(
+400,
+detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}",
+)
 
 # Load speaker embedding for the specified voice
 spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)
-
+
 # Inference parameters
 params_infer_main = {
 "text": [request.input],
@@ -145,13 +177,13 @@ async def generate_voice(request_data: Dict):
 "audio_seed": 12345678,
 # "text_seed": 87654321, # Random seed for text processing, used to control text refinement
 "do_text_normalization": True,  # Perform text normalization
-"do_homophone_replacement": True, # Perform homophone replacement
+"do_homophone_replacement": True,  # Perform homophone replacement
 }
-
+
 # Inference code parameters
 params_infer_code = app.state.chat.InferCodeParams(
-#prompt=f"[speed_{int(request.speed * 10)}]", # Convert to format supported by ChatTTS
-prompt="[speed_5]",
+# prompt=f"[speed_{int(request.speed * 10)}]",  # Convert to format supported by ChatTTS
+prompt="[speed_5]",
 top_P=0.5,
 top_K=10,
 temperature=0.1,
@@ -166,21 +198,21 @@ async def generate_voice(request_data: Dict):
 txt_smp=None,
 stream_batch=24,
 stream_speed=12000,
-pass_first_n_batches=2
+pass_first_n_batches=2,
 )
 
 try:
 async with app.state.model_lock:
 wavs = app.state.chat.infer(
-text = params_infer_main["text"],
-stream = params_infer_main["stream"],
-lang = params_infer_main["lang"],
-skip_refine_text = params_infer_main["skip_refine_text"],
-use_decoder = params_infer_main["use_decoder"],
-do_text_normalization = params_infer_main["do_text_normalization"],
-do_homophone_replacement = params_infer_main['do_homophone_replacement'],
-# params_refine_text = params_refine_text,
-params_infer_code=params_infer_code,
+text=params_infer_main["text"],
+stream=params_infer_main["stream"],
+lang=params_infer_main["lang"],
+skip_refine_text=params_infer_main["skip_refine_text"],
+use_decoder=params_infer_main["use_decoder"],
+do_text_normalization=params_infer_main["do_text_normalization"],
+do_homophone_replacement=params_infer_main["do_homophone_replacement"],
+# params_refine_text = params_refine_text,
+params_infer_code=params_infer_code,
 )
 except Exception as e:
 raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")
@@ -189,7 +221,7 @@ def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
 """Generate WAV file header (without data length)"""
 header = bytearray()
 header.extend(b"RIFF")
-header.extend(b"\xFF\xFF\xFF\xFF") # File size unknown
+header.extend(b"\xff\xff\xff\xff")  # File size unknown
 header.extend(b"WAVEfmt ")
 header.extend((16).to_bytes(4, "little"))  # fmt chunk size
 header.extend((1).to_bytes(2, "little"))  # PCM format
@@ -201,7 +233,7 @@ def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
 header.extend((block_align).to_bytes(2, "little"))  # Block align
 header.extend((bits_per_sample).to_bytes(2, "little"))  # Bits per sample
 header.extend(b"data")
-header.extend(b"\xFF\xFF\xFF\xFF") # Data size unknown
+header.extend(b"\xff\xff\xff\xff")  # Data size unknown
 return bytes(header)
 
 # Handle audio output format
@@ -210,35 +242,44 @@ def convert_audio(wav, format):
 if format == "mp3":
 return pcm_arr_to_mp3_view(wav)
 elif format == "wav":
-return pcm_arr_to_wav_view(wav, include_header=False) # No header in streaming
+return pcm_arr_to_wav_view(
+wav, include_header=False
+)  # No header in streaming
 elif format == "ogg":
 return pcm_arr_to_ogg_view(wav)
-return pcm_arr_to_mp3_view(wav)
-
+return pcm_arr_to_mp3_view(wav)
+
 # Return streaming audio data
 if request.stream:
 first_chunk = True
+
 async def audio_stream():
 nonlocal first_chunk
 for wav in wavs:
 if request.response_format == "wav" and first_chunk:
 yield generate_wav_header()  # Send WAV header
 first_chunk = False
 yield convert_audio(wav, request.response_format)
+
 media_type = "audio/wav" if request.response_format == "wav" else "audio/mpeg"
 return StreamingResponse(audio_stream(), media_type=media_type)
-
+
 # Return audio file directly
-if request.response_format == 'wav':
+if request.response_format == "wav":
 music_data = pcm_arr_to_wav_view(wavs[0])
 else:
 music_data = convert_audio(wavs[0], request.response_format)
-
-return StreamingResponse(io.BytesIO(music_data), media_type="audio/mpeg", headers={
-"Content-Disposition": f"attachment; filename=output.{request.response_format}"
-})
+
+return StreamingResponse(
+io.BytesIO(music_data),
+media_type="audio/mpeg",
+headers={
+"Content-Disposition": f"attachment; filename=output.{request.response_format}"
+},
+)
+
 
 @app.get("/health")
 async def health_check():
 """Health check endpoint"""
-return {"status": "healthy", "model_loaded": bool(app.state.chat)}
+return {"status": "healthy", "model_loaded": bool(app.state.chat)}

openai_api.ipynb

Lines changed: 13 additions & 14 deletions
@@ -29,23 +29,20 @@
 "from IPython.display import Audio, display\n",
 "\n",
 "# Initialize the client\n",
-"client = OpenAI(\n",
-" api_key=\"dummy-key\",\n",
-" base_url=\"http://localhost:8000/v1\"\n",
-")\n",
+"client = OpenAI(api_key=\"dummy-key\", base_url=\"http://localhost:8000/v1\")\n",
 "\n",
 "# Generate audio\n",
 "response = client.audio.speech.create(\n",
 " model=\"tts-1\",\n",
 " voice=\"echo\",\n",
-" input= \"\"\"\n",
+" input=\"\"\"\n",
 " 以下是一些中英文对照的话语。 \n",
 " 1. 早上好!希望你有美好的一天。Good morning! Wish you a wonderful day. \n",
 " 2. 你好呀,最近怎么样?Hello there, how have you been recently? \n",
 " 3. 别放弃,你能做到的!Don't give up, you can do it! \n",
 " 4. 继续努力,你的付出会有回报的。Keep up the good work, your efforts will pay off.\n",
 " \"\"\",\n",
-" response_format=\"wav\"\n",
+" response_format=\"wav\",\n",
 ")\n",
 "\n",
 "# Get audio binary data\n",
@@ -101,19 +98,21 @@
 " 4. 继续努力,你的付出会有回报的。Keep up the good work, your efforts will pay off.\n",
 " \"\"\",\n",
 " \"voice\": \"echo\",\n",
-" \"response_format\": \"wav\", \n",
-" \"stream\": True\n",
+" \"response_format\": \"wav\",\n",
+" \"stream\": True,\n",
 "}\n",
 "\n",
 "try:\n",
-" response = requests.post(\"http://localhost:8000/v1/audio/speech\", json=payload, stream=True)\n",
+" response = requests.post(\n",
+" \"http://localhost:8000/v1/audio/speech\", json=payload, stream=True\n",
+" )\n",
 " response.raise_for_status() # Check the status code\n",
-" \n",
+"\n",
 " audio_buffer = io.BytesIO()\n",
 " for chunk in response.iter_content(chunk_size=8192):\n",
 " if chunk:\n",
 " audio_buffer.write(chunk)\n",
-" \n",
+"\n",
 " audio_buffer.seek(0)\n",
 " display(Audio(audio_buffer.getvalue(), autoplay=False))\n",
 " print(\"Audio has been loaded into the Notebook and can be played manually\")\n",
@@ -462,7 +461,7 @@
 " 'curl -X POST \"http://localhost:8000/v1/audio/speech\" '\n",
 " '-H \"Content-Type: application/json\" '\n",
 " '-d \\'{\"model\": \"tts-1\", \"input\": \"以下是一些中英文对照的话语。 1. 早上好!希望你有美好的一天。Good morning! Wish you a wonderful day. 2. 你好呀,最近怎么样?Hello there, how have you been recently? 3. 别放弃,你能做到的!Dont give up, you can do it! 4. 继续努力,你的付出会有回报的。Keep up the good work, your efforts will pay off.\", \"voice\": \"echo\", \"response_format\": \"wav\", \"stream\": true}\\' '\n",
-" '-s | mpv --no-video -'\n",
+" \"-s | mpv --no-video -\"\n",
 ")\n",
 "subprocess.run(cmd, shell=True, check=True)"
 ]
@@ -1125,7 +1124,7 @@
 " 'curl -X POST \"http://localhost:8000/v1/audio/speech\" '\n",
 " '-H \"Content-Type: application/json\" '\n",
 " '-d \\'{\"model\": \"tts-1\", \"input\": \"以下是一些中英文对照的话语。 1. 早上好!希望你有美好的一天。Good morning! Wish you a wonderful day. 2. 你好呀,最近怎么样?Hello there, how have you been recently? 3. 别放弃,你能做到的!Dont give up, you can do it! 4. 继续努力,你的付出会有回报的。Keep up the good work, your efforts will pay off.\", \"voice\": \"echo\", \"response_format\": \"mp3\", \"stream\": true}\\' '\n",
-" '-s | mpv --no-video -'\n",
+" \"-s | mpv --no-video -\"\n",
 ")\n",
 "subprocess.run(cmd, shell=True, check=True)"
 ]
@@ -1690,7 +1689,7 @@
 " 'curl -X POST \"http://localhost:8000/v1/audio/speech\" '\n",
 " '-H \"Content-Type: application/json\" '\n",
 " '-d \\'{\"model\": \"tts-1\", \"input\": \"以下是一些中英文对照的话语。 1. 早上好!希望你有美好的一天。Good morning! Wish you a wonderful day. 2. 你好呀,最近怎么样?Hello there, how have you been recently? 3. 别放弃,你能做到的!Dont give up, you can do it! 4. 继续努力,你的付出会有回报的。Keep up the good work, your efforts will pay off.\", \"voice\": \"echo\", \"response_format\": \"ogg\", \"stream\": true}\\' '\n",
-" '-s | mpv --no-video -'\n",
+" \"-s | mpv --no-video -\"\n",
 ")\n",
 "subprocess.run(cmd, shell=True, check=True)"
 ]
