- Use an asyncio.Lock to manage model access so concurrent requests are handled safely
- Load and manage speaker embedding files to support personalized speech synthesis
"""

import io
import os
import sys
import asyncio
import time
from typing import Optional, Dict
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
# Allowed audio formats
ALLOWED_FORMATS = {"mp3", "wav", "ogg"}


@app.on_event("startup")
async def startup_event():
    """Load ChatTTS model and default speaker embedding when the application starts"""
    # Initialize ChatTTS and the async lock
    app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
    app.state.model_lock = asyncio.Lock()  # Use an async lock instead of a thread lock
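    # A single model instance is shared by all requests; awaiting the lock makes
    # each inference exclusive without blocking the event loop while waiting.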

    # Register text normalizers
    app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
    app.state.chat.normalizer.register("zh", normalizer_zh_tn())

    logger.info("Initializing ChatTTS...")
    if app.state.chat.load(source="huggingface"):
        logger.info("Model loaded successfully.")
    else:
        logger.error("Model loading failed, exiting application.")
        raise RuntimeError("Failed to load ChatTTS model")
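    # load(source="huggingface") fetches the ChatTTS weights from the Hugging
    # Face hub (typically cached locally after the first download).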

    # Preload all supported speaker embeddings into memory at startup so that
    # requests never reload them from disk at runtime
    app.state.spk_emb_map = {}
    for voice, spk_path in VOICE_MAP.items():
        if os.path.exists(spk_path):
            app.state.spk_emb_map[voice] = torch.load(
                spk_path, map_location=torch.device("cpu")
            )
            logger.info(f"Preloading speaker embedding: {voice} -> {spk_path}")
        else:
            logger.warning(f"Speaker embedding not found: {spk_path}, skipping preload")
    app.state.spk_emb = app.state.spk_emb_map.get("default")  # Default embedding


# Request parameter whitelist
ALLOWED_PARAMS = {
    "model",
    "input",
    "voice",
    "response_format",
    "speed",
    "stream",
    "output_format",
}


class OpenAITTSRequest(BaseModel):
    """OpenAI TTS request data model"""

    model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
    input: str = Field(
        ..., description="Text content to synthesize", max_length=2048
    )  # Length limit
    voice: Optional[str] = Field(
        "default", description="Voice selection, supports: default, alloy, echo"
    )
    response_format: Optional[str] = Field(
        "mp3", description="Audio format: mp3, wav, ogg"
    )
    speed: Optional[float] = Field(
        1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0"
    )
    stream: Optional[bool] = Field(False, description="Whether to stream")
    output_format: Optional[str] = "mp3"  # Optional formats: mp3, wav, ogg
    extra_params: Dict[str, Optional[str]] = Field(
        default_factory=dict, description="Unsupported extra parameters"
    )

    @classmethod
    def validate_request(cls, request_data: Dict):
        # Collect keys outside the whitelist (an assumed reconstruction of the
        # elided check implied by the warning below)
        unsupported_params = set(request_data) - ALLOWED_PARAMS
        if unsupported_params:
            logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
        return {key: request_data[key] for key in ALLOWED_PARAMS if key in request_data}


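# Example (hypothetical payload): unknown keys are logged and dropped, known keys kept.
#   OpenAITTSRequest.validate_request({"model": "tts-1", "input": "hi", "pitch": "+2"})
#   -> {"model": "tts-1", "input": "hi"}

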
# Unified error response
@app.exception_handler(Exception)
async def custom_exception_handler(request, exc):
    """Custom exception handler"""
    logger.error(f"Error: {str(exc)}")
    return JSONResponse(
        # HTTPException carries its own status_code; anything else becomes a 500
        status_code=getattr(exc, "status_code", 500),
        content={"error": {"message": str(exc), "type": exc.__class__.__name__}},
    )


@app.post("/v1/audio/speech")
async def generate_voice(request_data: Dict):
    """Handle speech synthesis request"""
    request_data = OpenAITTSRequest.validate_request(request_data)
    request = OpenAITTSRequest(**request_data)

    logger.info(
        f"Received request: text={request.input}..., voice={request.voice}, stream={request.stream}"
    )

    # Validate audio format
    if request.response_format not in ALLOWED_FORMATS:
        raise HTTPException(
            400,
            detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}",
        )

    # Look up the speaker embedding for the requested voice, falling back to the default
    spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)

    # Inference parameters
    params_infer_main = {
        "text": [request.input],
        # The next four entries are not shown in this excerpt; the values below
        # are assumptions matching how they are read in the infer() call
        "stream": request.stream,
        "lang": None,  # let ChatTTS auto-detect the language
        "skip_refine_text": False,
        "use_decoder": True,
        "audio_seed": 12345678,  # Random seed for audio generation
        # "text_seed": 87654321,  # Random seed for text processing, used to control text refinement
        "do_text_normalization": True,  # Perform text normalization
        "do_homophone_replacement": True,  # Perform homophone replacement
    }

    # Inference code parameters
    params_infer_code = app.state.chat.InferCodeParams(
        # prompt=f"[speed_{int(request.speed * 10)}]",  # Convert speed to the ChatTTS prompt format
        prompt="[speed_5]",  # Fixed speed for now; request.speed is not applied
        top_P=0.5,
        top_K=10,
        temperature=0.1,
        # ... (further arguments elided here, presumably including spk_emb=spk_emb) ...
        txt_smp=None,
        stream_batch=24,
        stream_speed=12000,
        pass_first_n_batches=2,
    )

    try:
        async with app.state.model_lock:
            wavs = app.state.chat.infer(
                text=params_infer_main["text"],
                stream=params_infer_main["stream"],
                lang=params_infer_main["lang"],
                skip_refine_text=params_infer_main["skip_refine_text"],
                use_decoder=params_infer_main["use_decoder"],
                do_text_normalization=params_infer_main["do_text_normalization"],
                do_homophone_replacement=params_infer_main["do_homophone_replacement"],
                # params_refine_text=params_refine_text,
                params_infer_code=params_infer_code,
            )
    except Exception as e:
        raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")
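    # When stream=True, infer() yields audio chunks as they are generated;
    # otherwise it returns a list with one complete waveform per input text
    # (hence wavs[0] in the non-streaming branch below).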

    def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
        """Generate WAV file header (without data length)"""
        header = bytearray()
        header.extend(b"RIFF")
        header.extend(b"\xff\xff\xff\xff")  # File size unknown
        header.extend(b"WAVEfmt ")
        header.extend((16).to_bytes(4, "little"))  # fmt chunk size
        header.extend((1).to_bytes(2, "little"))  # PCM format
        # The next five lines were elided in this excerpt; they are reconstructed
        # here from the standard WAV fmt-chunk layout
        header.extend((channels).to_bytes(2, "little"))  # Channel count
        header.extend((sample_rate).to_bytes(4, "little"))  # Sample rate
        byte_rate = sample_rate * channels * bits_per_sample // 8
        header.extend((byte_rate).to_bytes(4, "little"))  # Byte rate
        block_align = channels * bits_per_sample // 8
        header.extend((block_align).to_bytes(2, "little"))  # Block align
        header.extend((bits_per_sample).to_bytes(2, "little"))  # Bits per sample
        header.extend(b"data")
        header.extend(b"\xff\xff\xff\xff")  # Data size unknown
        return bytes(header)
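
    # Writing 0xFFFFFFFF for both size fields is a common convention for live
    # WAV streams: players then read audio until the stream ends.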

    # Handle audio output format
    def convert_audio(wav, format):
        if format == "mp3":
            return pcm_arr_to_mp3_view(wav)
        elif format == "wav":
            return pcm_arr_to_wav_view(
                wav, include_header=False
            )  # No header in streaming chunks
        elif format == "ogg":
            return pcm_arr_to_ogg_view(wav)
        return pcm_arr_to_mp3_view(wav)  # Fallback: default to MP3

    # Map the response format to its MIME type
    media_type = {"mp3": "audio/mpeg", "wav": "audio/wav", "ogg": "audio/ogg"}.get(
        request.response_format, "audio/mpeg"
    )

    # Return streaming audio data
    if request.stream:
        first_chunk = True

        async def audio_stream():
            nonlocal first_chunk
            for wav in wavs:
                if request.response_format == "wav" and first_chunk:
                    yield generate_wav_header()  # Send the WAV header before the first chunk
                    first_chunk = False
                yield convert_audio(wav, request.response_format)

        return StreamingResponse(audio_stream(), media_type=media_type)

    # Return the audio file directly
    if request.response_format == "wav":
        music_data = pcm_arr_to_wav_view(wavs[0])
    else:
        music_data = convert_audio(wavs[0], request.response_format)

    return StreamingResponse(
        io.BytesIO(music_data),
        media_type=media_type,
        headers={
            "Content-Disposition": f"attachment; filename=output.{request.response_format}"
        },
    )


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": bool(app.state.chat)}
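

# Example client call (a minimal sketch; the host, port, and output filename
# are assumptions, not part of this module):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/audio/speech",
#       json={
#           "model": "tts-1",
#           "input": "Hello, world",
#           "voice": "default",
#           "response_format": "mp3",
#       },
#   )
#   with open("output.mp3", "wb") as f:
#       f.write(resp.content)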