|
1 | 1 | import asyncio |
2 | 2 | import logging |
| 3 | +from http import HTTPStatus |
3 | 4 |
|
4 | | -from fastapi import APIRouter, WebSocket |
5 | | -from nexent.core.models.stt_model import STTConfig, STTModel |
6 | | -from nexent.core.models.tts_model import TTSConfig, TTSModel |
| 5 | +from fastapi import APIRouter, WebSocket, HTTPException, Body, Query |
| 6 | +from fastapi.responses import JSONResponse |
7 | 7 |
|
8 | | -from consts.const import APPID, CLUSTER, SPEED_RATIO, TEST_VOICE_PATH, TOKEN, VOICE_TYPE |
| 8 | +from consts.exceptions import ( |
| 9 | + VoiceServiceException, |
| 10 | + STTConnectionException, |
| 11 | + TTSConnectionException, |
| 12 | + VoiceConfigException |
| 13 | +) |
| 14 | +from consts.model import VoiceConnectivityRequest, VoiceConnectivityResponse |
| 15 | +from services.voice_service import get_voice_service |
9 | 16 |
|
10 | 17 | logger = logging.getLogger("voice_app") |
11 | 18 |
|
12 | | - |
13 | | -class VoiceService: |
14 | | - """Unified voice service that hosts both STT and TTS on a single FastAPI application""" |
15 | | - |
16 | | - def __init__(self): |
17 | | - """ |
18 | | - Initialize the voice service with configurations from const.py. |
19 | | - """ |
20 | | - # Initialize STT configuration |
21 | | - self.stt_config = STTConfig( |
22 | | - appid=APPID, |
23 | | - token=TOKEN |
| 19 | +router = APIRouter(prefix="/voice") |
| 20 | + |
| 21 | + |
| 22 | +@router.websocket("/stt/ws") |
| 23 | +async def stt_websocket(websocket: WebSocket): |
| 24 | + """WebSocket endpoint for real-time audio streaming and STT""" |
| 25 | + logger.info("STT WebSocket connection attempt...") |
| 26 | + await websocket.accept() |
| 27 | + logger.info("STT WebSocket connection accepted") |
| 28 | + |
| 29 | + try: |
| 30 | + voice_service = get_voice_service() |
| 31 | + await voice_service.start_stt_streaming_session(websocket) |
| 32 | + except STTConnectionException as e: |
| 33 | + logger.error(f"STT WebSocket error: {str(e)}") |
| 34 | + await websocket.send_json({"error": str(e)}) |
| 35 | + except Exception as e: |
| 36 | + logger.error(f"STT WebSocket error: {str(e)}") |
| 37 | + await websocket.send_json({"error": str(e)}) |
| 38 | + finally: |
| 39 | + logger.info("STT WebSocket connection closed") |
| 40 | + |
| 41 | + |
| 42 | +@router.websocket("/tts/ws") |
| 43 | +async def tts_websocket(websocket: WebSocket): |
| 44 | + """WebSocket endpoint for streaming TTS""" |
| 45 | + logger.info("TTS WebSocket connection attempt...") |
| 46 | + await websocket.accept() |
| 47 | + logger.info("TTS WebSocket connection accepted") |
| 48 | + |
| 49 | + try: |
| 50 | + # Receive text from client (single request) |
| 51 | + data = await websocket.receive_json() |
| 52 | + text = data.get("text") |
| 53 | + |
| 54 | + if not text: |
| 55 | + if websocket.client_state.name == "CONNECTED": |
| 56 | + await websocket.send_json({"error": "No text provided"}) |
| 57 | + return |
| 58 | + |
| 59 | + # Stream TTS audio to WebSocket |
| 60 | + voice_service = get_voice_service() |
| 61 | + await voice_service.stream_tts_to_websocket(websocket, text) |
| 62 | + |
| 63 | + except TTSConnectionException as e: |
| 64 | + logger.error(f"TTS WebSocket error: {str(e)}") |
| 65 | + await websocket.send_json({"error": str(e)}) |
| 66 | + except Exception as e: |
| 67 | + logger.error(f"TTS WebSocket error: {str(e)}") |
| 68 | + await websocket.send_json({"error": str(e)}) |
| 69 | + finally: |
| 70 | + logger.info("TTS WebSocket connection closed") |
| 71 | + # Ensure connection is properly closed |
| 72 | + if websocket.client_state.name == "CONNECTED": |
| 73 | + await websocket.close() |
| 74 | + |
| 75 | + |
| 76 | +@router.post("/connectivity") |
| 77 | +async def check_voice_connectivity(request: VoiceConnectivityRequest): |
| 78 | + """ |
| 79 | + Check voice service connectivity |
| 80 | + |
| 81 | + Args: |
| 82 | + request: VoiceConnectivityRequest containing model_type |
| 83 | + |
| 84 | + Returns: |
| 85 | + VoiceConnectivityResponse with connectivity status |
| 86 | + """ |
| 87 | + try: |
| 88 | + voice_service = get_voice_service() |
| 89 | + connected = await voice_service.check_voice_connectivity(request.model_type) |
| 90 | + |
| 91 | + return JSONResponse( |
| 92 | + status_code=HTTPStatus.OK, |
| 93 | + content=VoiceConnectivityResponse( |
| 94 | + connected=connected, |
| 95 | + model_type=request.model_type, |
| 96 | + message="Service is connected" if connected else "Service connection failed" |
| 97 | + ).dict() |
24 | 98 | ) |
25 | | - |
26 | | - # Initialize TTS configuration |
27 | | - self.tts_config = TTSConfig( |
28 | | - appid=APPID, |
29 | | - token=TOKEN, |
30 | | - cluster=CLUSTER, |
31 | | - voice_type=VOICE_TYPE, |
32 | | - speed_ratio=SPEED_RATIO |
| 99 | + except VoiceServiceException as e: |
| 100 | + logger.error(f"Voice service error: {str(e)}") |
| 101 | + raise HTTPException( |
| 102 | + status_code=HTTPStatus.BAD_REQUEST, |
| 103 | + detail=str(e) |
| 104 | + ) |
| 105 | + except (STTConnectionException, TTSConnectionException) as e: |
| 106 | + logger.error(f"Voice connectivity error: {str(e)}") |
| 107 | + raise HTTPException( |
| 108 | + status_code=HTTPStatus.SERVICE_UNAVAILABLE, |
| 109 | + detail=str(e) |
| 110 | + ) |
| 111 | + except VoiceConfigException as e: |
| 112 | + logger.error(f"Voice configuration error: {str(e)}") |
| 113 | + raise HTTPException( |
| 114 | + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, |
| 115 | + detail=str(e) |
| 116 | + ) |
| 117 | + except Exception as e: |
| 118 | + logger.error(f"Unexpected voice service error: {str(e)}") |
| 119 | + raise HTTPException( |
| 120 | + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, |
| 121 | + detail="Voice service error" |
33 | 122 | ) |
34 | | - |
35 | | - # Initialize models |
36 | | - self.stt_model = STTModel(self.stt_config, TEST_VOICE_PATH) |
37 | | - self.tts_model = TTSModel(self.tts_config) |
38 | | - |
39 | | - # Create FastAPI application |
40 | | - self.router = APIRouter(prefix="/voice") |
41 | | - |
42 | | - # Set up routes |
43 | | - self._setup_routes() |
44 | | - |
45 | | - def _setup_routes(self): |
46 | | - """Configure API routes for voice services""" |
47 | | - |
48 | | - # STT WebSocket route |
49 | | - @self.router.websocket("/stt/ws") |
50 | | - async def stt_websocket(websocket: WebSocket): |
51 | | - """WebSocket endpoint for real-time audio streaming and STT""" |
52 | | - logger.info("STT WebSocket connection attempt...") |
53 | | - await websocket.accept() |
54 | | - logger.info("STT WebSocket connection accepted") |
55 | | - try: |
56 | | - # Start streaming session |
57 | | - await self.stt_model.start_streaming_session(websocket) |
58 | | - except Exception as e: |
59 | | - logger.error(f"STT WebSocket error: {str(e)}") |
60 | | - import traceback |
61 | | - traceback.print_exc() |
62 | | - await websocket.send_json({"error": str(e)}) |
63 | | - finally: |
64 | | - logger.info("STT WebSocket connection closed") |
65 | | - |
66 | | - # TTS WebSocket route |
67 | | - @self.router.websocket("/tts/ws") |
68 | | - async def tts_websocket(websocket: WebSocket): |
69 | | - """WebSocket endpoint for streaming TTS""" |
70 | | - logger.info("TTS WebSocket connection attempt...") |
71 | | - await websocket.accept() |
72 | | - logger.info("TTS WebSocket connection accepted") |
73 | | - |
74 | | - try: |
75 | | - # Receive text from client (single request) |
76 | | - data = await websocket.receive_json() |
77 | | - text = data.get("text") |
78 | | - |
79 | | - if not text: |
80 | | - if websocket.client_state.name == "CONNECTED": |
81 | | - await websocket.send_json({"error": "No text provided"}) |
82 | | - return |
83 | | - |
84 | | - # Generate and stream audio chunks |
85 | | - try: |
86 | | - # First try to use it as a coroutine that returns an async iterator |
87 | | - speech_result = await self.tts_model.generate_speech(text, stream=True) |
88 | | - |
89 | | - # Check if it's an async iterator or a regular iterable |
90 | | - if hasattr(speech_result, '__aiter__'): |
91 | | - # It's an async iterator, use async for |
92 | | - async for chunk in speech_result: |
93 | | - if websocket.client_state.name == "CONNECTED": |
94 | | - await websocket.send_bytes(chunk) |
95 | | - else: |
96 | | - break |
97 | | - elif hasattr(speech_result, '__iter__'): |
98 | | - # It's a regular iterator, use normal for |
99 | | - for chunk in speech_result: |
100 | | - if websocket.client_state.name == "CONNECTED": |
101 | | - await websocket.send_bytes(chunk) |
102 | | - else: |
103 | | - break |
104 | | - else: |
105 | | - # It's a single chunk, send it directly |
106 | | - if websocket.client_state.name == "CONNECTED": |
107 | | - await websocket.send_bytes(speech_result) |
108 | | - |
109 | | - await asyncio.sleep(0.1) |
110 | | - |
111 | | - except TypeError as te: |
112 | | - # If speech_result is still a coroutine, try calling it directly without stream=True |
113 | | - if "async for" in str(te) and "requires an object with __aiter__" in str(te): |
114 | | - logger.error("Falling back to non-streaming TTS") |
115 | | - speech_data = await self.tts_model.generate_speech(text, stream=False) |
116 | | - if websocket.client_state.name == "CONNECTED": |
117 | | - await websocket.send_bytes(speech_data) |
118 | | - else: |
119 | | - raise |
120 | | - |
121 | | - # Send end marker after successful TTS generation |
122 | | - if websocket.client_state.name == "CONNECTED": |
123 | | - await websocket.send_json({"status": "completed"}) |
124 | | - |
125 | | - except Exception as e: |
126 | | - logger.error(f"TTS WebSocket error: {str(e)}") |
127 | | - import traceback |
128 | | - traceback.print_exc() |
129 | | - await websocket.send_json({"error": str(e)}) |
130 | | - finally: |
131 | | - logger.info("TTS WebSocket connection closed") |
132 | | - # Ensure connection is properly closed |
133 | | - if websocket.client_state.name == "CONNECTED": |
134 | | - await websocket.close() |
135 | | - |
136 | | - async def check_connectivity(self, model_type: str) -> bool: |
137 | | - """ |
138 | | - Check the connectivity status of voice services (STT and TTS) |
139 | | -
|
140 | | - Args: |
141 | | - model_type: The type of model to check, options are 'stt', 'tts' |
142 | | -
|
143 | | - Returns: |
144 | | - bool: Returns True if all services are connected normally, False if any service connection fails |
145 | | - """ |
146 | | - try: |
147 | | - stt_connected = False |
148 | | - tts_connected = False |
149 | | - |
150 | | - if model_type == 'stt': |
151 | | - logging.info(f'STT Config: {self.stt_config}') |
152 | | - stt_connected = await self.stt_model.check_connectivity() |
153 | | - if not stt_connected: |
154 | | - logging.error( |
155 | | - "Speech Recognition (STT) service connection failed") |
156 | | - |
157 | | - if model_type == 'tts': |
158 | | - logging.info(f'TTS Config: {self.tts_config}') |
159 | | - tts_connected = await self.tts_model.check_connectivity() |
160 | | - if not tts_connected: |
161 | | - logging.error( |
162 | | - "Text-to-Speech (TTS) service connection failed") |
163 | | - |
164 | | - # Return the corresponding connection status based on model_type |
165 | | - if model_type == 'stt': |
166 | | - return stt_connected |
167 | | - elif model_type == 'tts': |
168 | | - return tts_connected |
169 | | - else: |
170 | | - logging.error(f"Unknown model type: {model_type}") |
171 | | - return False |
172 | | - |
173 | | - except Exception as e: |
174 | | - logging.error( |
175 | | - f"Voice service connectivity test encountered an exception: {str(e)}") |
176 | | - return False |
177 | | - |
178 | | - |
179 | | -router = VoiceService().router |
0 commit comments