-
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathhandler.py
More file actions
300 lines (265 loc) · 12 KB
/
handler.py
File metadata and controls
300 lines (265 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""
RunPod Serverless Handler Wrapper for Kokoro FastAPI
This handler starts the existing FastAPI app internally and proxies requests to it.
"""
import asyncio
import base64
import json
import os
import subprocess
import sys
import tempfile
import threading
import time
from typing import Any, Dict

import requests
import runpod
from loguru import logger
# Global variables shared between the startup thread and the job handler.
fastapi_process = None  # subprocess.Popen handle for the internal FastAPI server
fastapi_ready = False   # flipped to True by start_fastapi() once /health returns 200
FASTAPI_URL = "http://localhost:8880"  # address of the internal server that handler() proxies to
def start_fastapi():
    """Launch the internal FastAPI server and block until it is healthy.

    Prefers ``/app/entrypoint.sh`` when present (made executable first),
    falling back to running uvicorn directly against ``api.src.main:app``.
    Then polls ``{FASTAPI_URL}/health`` until it returns 200.

    Side effects:
        Sets the module globals ``fastapi_process`` and ``fastapi_ready``.

    Raises:
        RuntimeError: if the server process exits early, or never becomes
            healthy within the timeout.
    """
    global fastapi_process, fastapi_ready
    logger.info("Starting internal FastAPI server...")

    # Capture child output in a temp file rather than subprocess.PIPE.
    # With PIPE and no concurrent reader, a chatty server (e.g. model
    # download progress) fills the OS pipe buffer and the child blocks
    # forever — the original code deadlocked here during the wait loop.
    log_file = tempfile.NamedTemporaryFile(
        mode="w+", prefix="fastapi_", suffix=".log", delete=False
    )

    def _dump_logs(log_fn) -> None:
        # Re-read the child's combined stdout/stderr for diagnostics.
        log_file.flush()
        try:
            with open(log_file.name, "r") as fh:
                log_fn(f"Process output:\n{fh.read()}")
        except OSError as exc:
            logger.warning(f"Could not read process log file: {exc}")

    method = "entrypoint"
    try:
        if os.path.exists("/app/entrypoint.sh"):
            logger.info("Found /app/entrypoint.sh, attempting to execute...")
            # Make sure it's executable (images sometimes lose the x bit).
            os.chmod("/app/entrypoint.sh", 0o755)
            fastapi_process = subprocess.Popen(
                ["/app/entrypoint.sh"],
                stdout=log_file,
                stderr=subprocess.STDOUT,
                text=True,
            )
        else:
            raise FileNotFoundError
    except (FileNotFoundError, PermissionError):
        method = "direct_uvicorn"
        logger.info("entrypoint.sh not found or failed, falling back to direct uvicorn...")
        fastapi_process = subprocess.Popen(
            [
                sys.executable, "-m", "uvicorn",
                "api.src.main:app",
                "--host", "0.0.0.0",
                "--port", "8880",
            ],
            stdout=log_file,
            stderr=subprocess.STDOUT,
            text=True,
            cwd="/app",
        )

    # Wait for the server to be ready; generous because the first boot
    # may need to download model weights.
    max_wait = 300  # 5 minutes max wait for model download
    start_time = time.time()
    while time.time() - start_time < max_wait:
        # Fail fast if the child already died instead of polling forever.
        if fastapi_process.poll() is not None:
            logger.error(
                f"FastAPI process exited unexpectedly with code {fastapi_process.returncode}"
            )
            _dump_logs(logger.error)
            raise RuntimeError(f"FastAPI server crashed on startup ({method})")
        try:
            response = requests.get(f"{FASTAPI_URL}/health", timeout=5)
            if response.status_code == 200:
                fastapi_ready = True
                logger.info("FastAPI server is ready!")
                return
        except requests.exceptions.RequestException:
            # Server not accepting connections yet; keep polling.
            pass
        time.sleep(2)

    logger.error("FastAPI server failed to start within timeout")
    # If we timed out, dump the logs anyway to see what was happening.
    if fastapi_process:
        logger.info("Dumping process logs due to timeout...")
        fastapi_process.kill()
        fastapi_process.wait()
        _dump_logs(logger.info)
    raise RuntimeError("FastAPI server startup timeout")
def wait_for_fastapi():
    """Block until the background FastAPI server reports ready.

    Polls the module-level ``fastapi_ready`` flag once per second for at
    most five minutes.

    Raises:
        RuntimeError: if the flag never becomes True before the deadline.
    """
    deadline = time.time() + 300  # 5 minutes
    while time.time() < deadline:
        if fastapi_ready:
            return
        time.sleep(1)
    if not fastapi_ready:
        raise RuntimeError("FastAPI server not ready")
def _b64(data: bytes) -> str:
    """Base64-encode raw bytes for inclusion in a JSON response."""
    return base64.b64encode(data).decode('utf-8')


def _strip_routing(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the job input without the handler-only routing keys."""
    payload = job_input.copy()
    payload.pop("endpoint", None)
    payload.pop("method", None)
    return payload


def _proxy_get_json(path: str, key: str) -> Dict[str, Any]:
    """GET a JSON listing endpoint and wrap the body under *key*."""
    response = requests.get(f"{FASTAPI_URL}{path}", timeout=30)
    if response.status_code == 200:
        return {"success": True, key: response.json()}
    return {"success": False, "error": f"Failed to get {key}: {response.text}"}


def _phonemize(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """POST /dev/phonemize — convert text to phonemes."""
    payload = {
        "text": job_input.get("text", ""),
        "language": job_input.get("language", "a")
    }
    response = requests.post(f"{FASTAPI_URL}/dev/phonemize", json=payload, timeout=60)
    if response.status_code == 200:
        return {"success": True, "result": response.json()}
    return {"success": False, "error": f"Phonemize failed: {response.text}"}


def _generate_from_phonemes(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """POST /dev/generate_from_phonemes — render audio from phonemes."""
    payload = {
        "phonemes": job_input.get("phonemes", ""),
        "voice": job_input.get("voice", "af_bella")
    }
    response = requests.post(
        f"{FASTAPI_URL}/dev/generate_from_phonemes", json=payload, timeout=300
    )
    if response.status_code == 200:
        audio_data = response.content
        return {
            "success": True,
            "audio_base64": _b64(audio_data),
            "voice": payload["voice"],
            "size_bytes": len(audio_data)
        }
    return {"success": False, "error": f"Phoneme generation failed: {response.text}"}


def _combine_voices(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """POST /v1/audio/voices/combine — combine voices into one voice file."""
    voices = job_input.get("voices", "")
    response = requests.post(
        f"{FASTAPI_URL}/v1/audio/voices/combine", json=voices, timeout=120
    )
    if response.status_code == 200:
        # Voice combination returns a binary file — encode as base64.
        file_data = response.content
        return {
            "success": True,
            "voice_file_base64": _b64(file_data),
            "voices": voices,
            "size_bytes": len(file_data)
        }
    return {"success": False, "error": f"Voice combination failed: {response.text}"}


def _captioned_speech(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """POST /dev/captioned_speech — TTS with word-level timestamps."""
    fastapi_payload = _strip_routing(job_input)
    # Set defaults for captioned speech.
    fastapi_payload.setdefault("input", job_input.get("text", ""))
    fastapi_payload.setdefault("model", "kokoro")
    fastapi_payload.setdefault("voice", "af_bella")
    fastapi_payload.setdefault("response_format", "mp3")
    response = requests.post(
        f"{FASTAPI_URL}/dev/captioned_speech", json=fastapi_payload, timeout=300
    )
    if response.status_code == 200:
        # Audio (when present) is already base64-encoded by the upstream
        # API, so the body passes through unchanged.
        return {"success": True, "result": response.json()}
    return {"success": False, "error": f"Captioned speech failed: {response.text}"}


def _build_speech_payload(job_input: Dict[str, Any]):
    """Build the OpenAI-style /v1/audio/speech payload.

    Accepts either the OpenAI-compatible format (string "input") passed
    through directly, or the simple format ("text" plus options) which is
    converted. Returns None when neither form supplies the text.
    """
    if "input" in job_input and isinstance(job_input["input"], str):
        # OpenAI-compatible format — pass through directly.
        return _strip_routing(job_input)
    if "text" in job_input:
        # Simple format — convert to OpenAI format.
        payload = {
            "model": job_input.get("model", "kokoro"),
            "input": job_input.get("text"),
            "voice": job_input.get("voice", "af_bella"),
            "response_format": job_input.get("format", job_input.get("response_format", "mp3")),
            "speed": job_input.get("speed", 1.0)
        }
        # Copy other optional parameters verbatim.
        for key in ["stream", "return_download_link", "lang_code", "normalization_options"]:
            if key in job_input:
                payload[key] = job_input[key]
        return payload
    return None


def _speech(job_input: Dict[str, Any]) -> Dict[str, Any]:
    """POST /v1/audio/speech — the default TTS endpoint."""
    fastapi_payload = _build_speech_payload(job_input)
    if fastapi_payload is None:
        # Fix: include "success": False for consistency with every other
        # error path (the original omitted it here only).
        return {"success": False, "error": "Missing required parameter: 'input' or 'text'"}
    logger.info(f"Forwarding TTS request: {fastapi_payload.get('input', '')[:50]}...")
    # Forward request to the internal FastAPI server.
    response = requests.post(
        f"{FASTAPI_URL}/v1/audio/speech",
        json=fastapi_payload,
        timeout=300,  # 5 minutes timeout
        headers={"Content-Type": "application/json"}
    )
    if response.status_code == 200:
        # Convert the binary audio response to base64.
        audio_data = response.content
        return {
            "success": True,
            "audio_base64": _b64(audio_data),
            "text": fastapi_payload.get("input", ""),
            "voice": fastapi_payload.get("voice", ""),
            "speed": fastapi_payload.get("speed", 1.0),
            "format": fastapi_payload.get("response_format", "mp3"),
            "model": fastapi_payload.get("model", "kokoro"),
            "size_bytes": len(audio_data)
        }
    error_msg = f"FastAPI error: {response.status_code} - {response.text}"
    logger.error(error_msg)
    return {
        "success": False,
        "error": error_msg
    }


def handler(job: Dict[str, Any]) -> Dict[str, Any]:
    """
    RunPod serverless handler that proxies to the internal FastAPI server.

    Supports all endpoints from the original Kokoro FastAPI:
    - /v1/audio/speech (TTS generation, the default)
    - /v1/audio/voices (list voices)
    - /v1/models (list models)
    - /dev/captioned_speech (TTS with timestamps)
    - /dev/phonemize (text to phonemes)
    - /dev/generate_from_phonemes (phonemes to audio)
    - /v1/audio/voices/combine (voice combination)

    The job input selects the route via "endpoint" and "method"; remaining
    keys are forwarded. Always returns a JSON-serializable dict with a
    "success" flag (errors carry an "error" message instead of raising).
    """
    try:
        # Block until the background server is healthy.
        wait_for_fastapi()
        job_input = job.get("input", {})
        # Determine endpoint and method from the input.
        endpoint = job_input.get("endpoint", "/v1/audio/speech")
        method = job_input.get("method", "POST").upper()

        # Route to the matching proxy helper.
        if endpoint == "/v1/audio/voices" and method == "GET":
            return _proxy_get_json("/v1/audio/voices", "voices")
        if endpoint == "/v1/models" and method == "GET":
            return _proxy_get_json("/v1/models", "models")
        if endpoint == "/dev/phonemize" and method == "POST":
            return _phonemize(job_input)
        if endpoint == "/dev/generate_from_phonemes" and method == "POST":
            return _generate_from_phonemes(job_input)
        if endpoint == "/v1/audio/voices/combine" and method == "POST":
            return _combine_voices(job_input)
        if endpoint == "/dev/captioned_speech" and method == "POST":
            return _captioned_speech(job_input)
        # Default: /v1/audio/speech endpoint.
        return _speech(job_input)
    except requests.exceptions.Timeout:
        logger.error("Request timeout")
        return {
            "success": False,
            "error": "Request timeout - request took too long"
        }
    except Exception as e:
        logger.error(f"Handler error: {e}")
        return {
            "success": False,
            "error": str(e)
        }
# Start the FastAPI server in a background thread at import time, so the
# (potentially slow) model download and server boot overlap with RunPod's
# worker initialization; handler() blocks on readiness via wait_for_fastapi().
logger.info("Initializing Kokoro FastAPI Serverless Wrapper...")
threading.Thread(target=start_fastapi, daemon=True).start()
# Start the serverless worker
# Start the serverless worker (blocks forever, dispatching each RunPod
# job to handler()).
if __name__ == "__main__":
    logger.info("Starting RunPod serverless worker...")
    runpod.serverless.start({"handler": handler})