1717
1818from src .easevoice .inference import InferenceResult , InferenceTask , InferenceTaskData , Runner
1919from src .logger import logger
20+ from src .train import sovits
21+ from src .train .helper import list_train_gpts , list_train_sovits
2022from src .utils .response import EaseVoiceResponse , ResponseStatus
2123
24+
2225class VoiceCloneStatus (Enum ):
2326 RUNNING = "Running"
2427 COMPLETED = "Completed"
2528 ERROR = "Error"
2629
30+
2731class VoiceCloneService :
2832 """
2933 VoiceService is a long run service that listens for voice clone tasks and processes them.
@@ -40,7 +44,7 @@ def close(self):
4044 self .runner_process .terminate ()
4145 self .runner_process .join (timeout = 10 )
4246 self .runner_process = None
43-
47+
4448 def get_status (self ):
4549 if self .runner_process is None :
4650 return VoiceCloneStatus .COMPLETED
@@ -64,6 +68,7 @@ def clone(self, params: dict):
6468 data = InferenceTaskData (** params )
6569 queue = mp .Queue ()
6670 infer_task = InferenceTask (result_queue = queue , data = data )
71+ infer_task = self .update_task_path (infer_task )
6772 self .queue .put (infer_task )
6873 result : InferenceResult = infer_task .result_queue .get (timeout = 600 )
6974 except Exception as e :
@@ -84,3 +89,25 @@ def clone(self, params: dict):
8489 except Exception as e :
8590 logger .error (f"failed to clone voice for { params } , error: { e } " , exc_info = True )
8691 return EaseVoiceResponse (ResponseStatus .FAILED , "failed to clone voice" )
92+
93+ def update_task_path (self , task : InferenceTask ):
94+ if task .data .gpt_path == "default" :
95+ task .data .gpt_path = ""
96+ if task .data .sovits_path == "default" :
97+ task .data .sovits_path = ""
98+
99+ if task .data .gpt_path != "" :
100+ gpts = list_train_gpts ()
101+ if task .data .gpt_path in gpts :
102+ task .data .gpt_path = gpts [task .data .gpt_path ]
103+ else :
104+ logger .error (f"failed to find gpt model for { task .data .gpt_path } " )
105+ raise ValueError (f"failed to find gpt model for { task .data .gpt_path } " )
106+ if task .data .sovits_path != "" :
107+ sovits = list_train_sovits ()
108+ if task .data .sovits_path in sovits :
109+ task .data .sovits_path = sovits [task .data .sovits_path ]
110+ else :
111+ logger .error (f"failed to find sovits model for { task .data .sovits_path } " )
112+ raise ValueError (f"failed to find sovits model for { task .data .sovits_path } " )
113+ return task
0 commit comments