1+ """
2+ openai_api.py
3+ This module implements a FastAPI-based text-to-speech API compatible with OpenAI's interface specification.
4+
5+ Main features and improvements:
6+ - Use app.state to manage global state, ensuring thread safety
7+ - Add exception handling and unified error responses to improve stability
8+ - Support multiple voice options and audio formats for greater flexibility
9+ - Add input validation to ensure the validity of request parameters
10+ - Support additional OpenAI TTS parameters (e.g., speed) for richer functionality
11+ - Implement health check endpoint for easy service status monitoring
12+ - Use asyncio.Lock to manage model access, improving concurrency performance
13+ - Load and manage speaker embedding files to support personalized speech synthesis
14+ """
15+ import io
16+ import os
17+ import sys
18+ import asyncio
19+ import time
20+ from typing import Optional , Dict
21+ from fastapi import FastAPI , HTTPException
22+ from fastapi .responses import StreamingResponse , JSONResponse
23+ from pydantic import BaseModel , Field
24+ import torch
25+
26+ # Cross-platform compatibility settings
27+ if sys .platform == "darwin" :
28+ os .environ ["PYTORCH_ENABLE_MPS_FALLBACK" ] = "1"
29+
30+ # Set working directory and add to system path
31+ now_dir = os .getcwd ()
32+ sys .path .append (now_dir )
33+
34+ # Import necessary modules
35+ import ChatTTS
36+ from tools .audio import pcm_arr_to_mp3_view , pcm_arr_to_ogg_view , pcm_arr_to_wav_view
37+ from tools .logger import get_logger
38+ from tools .normalizer .en import normalizer_en_nemo_text
39+ from tools .normalizer .zh import normalizer_zh_tn
40+
41+ # Initialize logger
42+ logger = get_logger ("Command" )
43+
44+ # Initialize FastAPI application
45+ app = FastAPI ()
46+
# Voice mapping table: OpenAI voice name -> local speaker-embedding file.
# Stable voices can be downloaded from:
#   ModelScope Community: https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker
#   HuggingFace: https://huggingface.co/spaces/taa/ChatTTS_Speaker
VOICE_MAP = {
    "default": "1528.pt",
    "alloy": "1384.pt",
    "echo": "2443.pt",
}

# Audio container formats the API will accept
ALLOWED_FORMATS = {"mp3", "wav", "ogg"}
59+
@app.on_event("startup")
async def startup_event():
    """Load the ChatTTS model and preload speaker embeddings at startup.

    All shared state lives on ``app.state`` so request handlers use a single
    model instance serialized through one asyncio lock.
    """
    # One model instance plus an async lock guarding inference access.
    app.state.chat = ChatTTS.Chat(get_logger("ChatTTS"))
    app.state.model_lock = asyncio.Lock()  # Async lock instead of a thread lock

    # Register text normalizers for English and Chinese input.
    app.state.chat.normalizer.register("en", normalizer_en_nemo_text())
    app.state.chat.normalizer.register("zh", normalizer_zh_tn())

    logger.info("Initializing ChatTTS...")
    if not app.state.chat.load(source="huggingface"):
        logger.error("Model loading failed, exiting application.")
        raise RuntimeError("Failed to load ChatTTS model")
    logger.info("Model loaded successfully.")

    # Preload every supported speaker embedding once, so requests never
    # touch the filesystem during inference.
    embeddings = {}
    for name, path in VOICE_MAP.items():
        if not os.path.exists(path):
            logger.warning(f"Speaker embedding not found: {path}, skipping preload")
            continue
        embeddings[name] = torch.load(path, map_location=torch.device("cpu"))
        logger.info(f"Preloading speaker embedding: {name} -> {path}")
    app.state.spk_emb_map = embeddings
    app.state.spk_emb = embeddings.get("default")  # Fallback embedding (may be None if file missing)
88+
# Whitelist of request parameters accepted by /v1/audio/speech;
# anything else sent by the client is dropped (with a warning).
ALLOWED_PARAMS = {
    "model",
    "input",
    "voice",
    "response_format",
    "speed",
    "stream",
    "output_format",
}
91+
class OpenAITTSRequest(BaseModel):
    """OpenAI-compatible TTS request data model."""

    # Speech synthesis model; the service only supports (and forces) 'tts-1'.
    model: str = Field(..., description="Speech synthesis model, fixed as 'tts-1'")
    # Text to synthesize; length-capped to keep inference time bounded.
    input: str = Field(..., description="Text content to synthesize", max_length=2048)
    voice: Optional[str] = Field("default", description="Voice selection, supports: default, alloy, echo")
    response_format: Optional[str] = Field("mp3", description="Audio format: mp3, wav, ogg")
    speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speed, range 0.5-2.0")
    stream: Optional[bool] = Field(False, description="Whether to stream")
    output_format: Optional[str] = "mp3"  # Optional formats: mp3, wav, ogg
    # NOTE(review): kept for interface compatibility, but validate_request
    # strips keys outside ALLOWED_PARAMS, so this is never populated from
    # client input.
    extra_params: Dict[str, Optional[str]] = Field(default_factory=dict, description="Unsupported extra parameters")

    @classmethod
    def validate_request(cls, request_data: Dict) -> Dict:
        """Return a sanitized copy of *request_data*.

        Forces ``model`` to 'tts-1' and drops any parameter outside
        ALLOWED_PARAMS, logging a warning for the ignored keys.

        Fix: operate on a shallow copy so the caller's dict is no longer
        mutated in place (the original wrote ``model`` into it directly).
        """
        data = dict(request_data)
        data["model"] = "tts-1"  # Unify model value
        unsupported_params = set(data.keys()) - ALLOWED_PARAMS
        if unsupported_params:
            logger.warning(f"Ignoring unsupported parameters: {unsupported_params}")
        return {key: data[key] for key in ALLOWED_PARAMS if key in data}
111+
# Unified error response for any unhandled exception.
@app.exception_handler(Exception)
async def custom_exception_handler(request, exc):
    """Translate an unhandled exception into a JSON error payload.

    Uses the exception's own ``status_code`` attribute when present,
    otherwise responds with HTTP 500.
    """
    logger.error(f"Error: {str(exc)}")
    status = getattr(exc, "status_code", 500)
    payload = {"error": {"message": str(exc), "type": exc.__class__.__name__}}
    return JSONResponse(status_code=status, content=payload)
121+
@app.post("/v1/audio/speech")
async def generate_voice(request_data: Dict):
    """Handle a speech synthesis request (OpenAI /v1/audio/speech compatible).

    Validates the payload, runs ChatTTS inference while holding the shared
    model lock, and returns the audio either as a streaming response or as
    a single downloadable file in the requested format.

    Raises:
        HTTPException(400): unsupported audio format.
        HTTPException(500): inference failure.
    """
    request_data = OpenAITTSRequest.validate_request(request_data)
    request = OpenAITTSRequest(**request_data)

    logger.info(f"Received request: text={request.input}..., voice={request.voice}, stream={request.stream}")

    # Validate audio format
    if request.response_format not in ALLOWED_FORMATS:
        raise HTTPException(400, detail=f"Unsupported audio format: {request.response_format}, supported formats: {', '.join(ALLOWED_FORMATS)}")

    # Speaker embedding for the requested voice; fall back to the default.
    spk_emb = app.state.spk_emb_map.get(request.voice, app.state.spk_emb)

    # Main inference parameters
    params_infer_main = {
        "text": [request.input],
        "stream": request.stream,
        "lang": None,
        "skip_refine_text": True,  # Do not use text refinement
        "refine_text_only": False,
        "use_decoder": True,
        "audio_seed": 12345678,
        # "text_seed": 87654321,  # Random seed for text processing, used to control text refinement
        "do_text_normalization": True,  # Perform text normalization
        "do_homophone_replacement": True,  # Perform homophone replacement
    }

    # Inference code parameters
    params_infer_code = app.state.chat.InferCodeParams(
        # prompt=f"[speed_{int(request.speed * 10)}]",  # Convert to format supported by ChatTTS
        # NOTE(review): request.speed is currently ignored — the prompt is
        # pinned to speed 5. Confirm before wiring the parameter through.
        prompt="[speed_5]",
        top_P=0.5,
        top_K=10,
        temperature=0.1,
        repetition_penalty=1.1,
        max_new_token=2048,
        min_new_token=0,
        show_tqdm=True,
        ensure_non_empty=True,
        manual_seed=42,
        spk_emb=spk_emb,
        spk_smp=None,
        txt_smp=None,
        stream_batch=24,
        stream_speed=12000,
        pass_first_n_batches=2,
    )

    try:
        # Serialize model access: ChatTTS inference is not assumed to be
        # safe for concurrent calls.
        async with app.state.model_lock:
            wavs = app.state.chat.infer(
                text=params_infer_main["text"],
                stream=params_infer_main["stream"],
                lang=params_infer_main["lang"],
                skip_refine_text=params_infer_main["skip_refine_text"],
                use_decoder=params_infer_main["use_decoder"],
                do_text_normalization=params_infer_main["do_text_normalization"],
                do_homophone_replacement=params_infer_main["do_homophone_replacement"],
                # params_refine_text = params_refine_text,
                params_infer_code=params_infer_code,
            )
    except Exception as e:
        raise HTTPException(500, detail=f"Speech synthesis failed: {str(e)}")

    def generate_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
        """Generate a streaming WAV header with unknown (0xFFFFFFFF) sizes."""
        header = bytearray()
        header.extend(b"RIFF")
        header.extend(b"\xFF\xFF\xFF\xFF")  # File size unknown
        header.extend(b"WAVEfmt ")
        header.extend((16).to_bytes(4, "little"))  # fmt chunk size
        header.extend((1).to_bytes(2, "little"))  # PCM format
        header.extend(channels.to_bytes(2, "little"))  # Channels
        header.extend(sample_rate.to_bytes(4, "little"))  # Sample rate
        byte_rate = sample_rate * channels * bits_per_sample // 8
        header.extend(byte_rate.to_bytes(4, "little"))  # Byte rate
        block_align = channels * bits_per_sample // 8
        header.extend(block_align.to_bytes(2, "little"))  # Block align
        header.extend(bits_per_sample.to_bytes(2, "little"))  # Bits per sample
        header.extend(b"data")
        header.extend(b"\xFF\xFF\xFF\xFF")  # Data size unknown
        return bytes(header)

    def convert_audio(wav, format):
        """Convert a PCM array to the requested container format."""
        if format == "mp3":
            return pcm_arr_to_mp3_view(wav)
        elif format == "wav":
            return pcm_arr_to_wav_view(wav, include_header=False)  # No header in streaming
        elif format == "ogg":
            return pcm_arr_to_ogg_view(wav)
        return pcm_arr_to_mp3_view(wav)

    # fix: report the correct MIME type per format. The original sent
    # "audio/mpeg" for streamed ogg and for EVERY non-streaming response,
    # including wav and ogg.
    media_types = {"mp3": "audio/mpeg", "wav": "audio/wav", "ogg": "audio/ogg"}
    media_type = media_types.get(request.response_format, "audio/mpeg")

    # Streaming response: yield chunks as inference produces them.
    if request.stream:
        first_chunk = True

        async def audio_stream():
            nonlocal first_chunk
            for wav in wavs:
                if request.response_format == "wav" and first_chunk:
                    yield generate_wav_header()  # Send WAV header once, before any data
                    first_chunk = False
                yield convert_audio(wav, request.response_format)

        return StreamingResponse(audio_stream(), media_type=media_type)

    # Non-streaming: return the whole clip as a single attachment.
    if request.response_format == "wav":
        music_data = pcm_arr_to_wav_view(wavs[0])  # With header for a complete file
    else:
        music_data = convert_audio(wavs[0], request.response_format)

    return StreamingResponse(io.BytesIO(music_data), media_type=media_type, headers={
        "Content-Disposition": f"attachment; filename=output.{request.response_format}"
    })
240+
@app.get("/health")
async def health_check():
    """Report service liveness and whether the TTS model object exists."""
    model_ready = bool(app.state.chat)
    return {"status": "healthy", "model_loaded": model_ready}