Skip to content

Commit 3d69b74

Browse files
committed
add train gpt and sovits api
1 parent 251f184 commit 3d69b74

File tree

1 file changed

+48
-2
lines changed

1 file changed

+48
-2
lines changed

src/rest/rest.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
from src.service.audio import AudioService
1717
from src.service.file import FileService
1818
from src.service.namespace import NamespaceService
19+
from src.service.train import TrainGPTService, TrainSovitsService
1920
from src.service.voice import VoiceCloneService
20-
from src.service.session import SessionManager
21+
from src.service.session import SessionManager, session_guard
22+
from src.train.gpt import GPTTrainParams
23+
from src.train.sovits import SovitsTrainParams
2124
from src.utils.response import EaseVoiceResponse
2225
from src.logger import logger
2326

@@ -275,11 +278,48 @@ async def stop_service(self):
275278
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail={"error": msg})
276279

277280

281+
class TrainAPI:
282+
def __init__(self):
283+
self.router = APIRouter()
284+
self._register_routes()
285+
286+
def _register_routes(self):
287+
self.router.post("/train/gpt")(self.train_gpt)
288+
self.router.post("/train/sovits")(self.train_sovits)
289+
290+
def train_gpt(self, params: GPTTrainParams):
291+
result = self._do_train_gpt(params)
292+
293+
# session_guard wrapper return a dict
294+
if isinstance(result, EaseVoiceResponse):
295+
return result
296+
logger.error(f"failed to train gpt: {result}")
297+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
298+
299+
async def train_sovits(self, params: SovitsTrainParams):
300+
result = self._do_train_sovits(params)
301+
302+
# session_guard wrapper return a dict
303+
if isinstance(result, EaseVoiceResponse):
304+
return result
305+
logger.error(f"failed to train sovits: {result}")
306+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
307+
308+
@session_guard("TrainGPT")
309+
def _do_train_gpt(self, params: GPTTrainParams):
310+
service = TrainGPTService(params)
311+
return service.train()
312+
313+
@session_guard("TrainSovits")
314+
def _do_train_sovits(self, params: SovitsTrainParams):
315+
service = TrainSovitsService(params)
316+
return service.train()
317+
318+
278319
# Initialize FastAPI and NamespaceService
279320
app = FastAPI()
280321

281322
namespace_service = NamespaceService()
282-
voice_service = VoiceCloneService()
283323
namespace_api = NamespaceAPI(namespace_service)
284324
app.include_router(namespace_api.router, prefix="/apis/v1")
285325

@@ -290,3 +330,9 @@ async def stop_service(self):
290330
session_manager = SessionManager()
291331
session_api = SessionAPI(session_manager)
292332
app.include_router(session_api.router, prefix="/apis/v1")
333+
334+
voice_clone_api = VoiceCloneAPI(session_manager)
335+
app.include_router(voice_clone_api.router, prefix="/apis/v1")
336+
337+
train_api = TrainAPI()
338+
app.include_router(train_api.router, prefix="/apis/v1")

0 commit comments

Comments
 (0)