1616from src .service .audio import AudioService
1717from src .service .file import FileService
1818from src .service .namespace import NamespaceService
19+ from src .service .train import TrainGPTService , TrainSovitsService
1920from src .service .voice import VoiceCloneService
20- from src .service .session import SessionManager
21+ from src .service .session import SessionManager , session_guard
22+ from src .train .gpt import GPTTrainParams
23+ from src .train .sovits import SovitsTrainParams
2124from src .utils .response import EaseVoiceResponse
2225from src .logger import logger
2326
@@ -275,11 +278,48 @@ async def stop_service(self):
275278 raise HTTPException (status_code = HTTPStatus .INTERNAL_SERVER_ERROR , detail = {"error" : msg })
276279
277280
281+ class TrainAPI :
282+ def __init__ (self ):
283+ self .router = APIRouter ()
284+ self ._register_routes ()
285+
286+ def _register_routes (self ):
287+ self .router .post ("/train/gpt" )(self .train_gpt )
288+ self .router .post ("/train/sovits" )(self .train_sovits )
289+
290+ def train_gpt (self , params : GPTTrainParams ):
291+ result = self ._do_train_gpt (params )
292+
293+ # session_guard wrapper return a dict
294+ if isinstance (result , EaseVoiceResponse ):
295+ return result
296+ logger .error (f"failed to train gpt: { result } " )
297+ raise HTTPException (status_code = HTTPStatus .INTERNAL_SERVER_ERROR , detail = result )
298+
299+ async def train_sovits (self , params : SovitsTrainParams ):
300+ result = self ._do_train_sovits (params )
301+
302+ # session_guard wrapper return a dict
303+ if isinstance (result , EaseVoiceResponse ):
304+ return result
305+ logger .error (f"failed to train sovits: { result } " )
306+ raise HTTPException (status_code = HTTPStatus .INTERNAL_SERVER_ERROR , detail = result )
307+
308+ @session_guard ("TrainGPT" )
309+ def _do_train_gpt (self , params : GPTTrainParams ):
310+ service = TrainGPTService (params )
311+ return service .train ()
312+
313+ @session_guard ("TrainSovits" )
314+ def _do_train_sovits (self , params : SovitsTrainParams ):
315+ service = TrainSovitsService (params )
316+ return service .train ()
317+
318+
278319# Initialize FastAPI and NamespaceService
279320app = FastAPI ()
280321
281322namespace_service = NamespaceService ()
282- voice_service = VoiceCloneService ()
283323namespace_api = NamespaceAPI (namespace_service )
284324app .include_router (namespace_api .router , prefix = "/apis/v1" )
285325
@@ -290,3 +330,9 @@ async def stop_service(self):
290330session_manager = SessionManager ()
291331session_api = SessionAPI (session_manager )
292332app .include_router (session_api .router , prefix = "/apis/v1" )
333+
334+ voice_clone_api = VoiceCloneAPI (session_manager )
335+ app .include_router (voice_clone_api .router , prefix = "/apis/v1" )
336+
337+ train_api = TrainAPI ()
338+ app .include_router (train_api .router , prefix = "/apis/v1" )
0 commit comments