11import base64
2- from concurrent .futures import thread
32from enum import Enum
43import gc
54import io
65import multiprocessing as mp
7- import os
8- import queue
9- import threading
10- import time
116import numpy as np
12- from scipy .io import wavfile
137import soundfile as sf
8+ import torch
149
15- from src .api .api import ServiceNames , TaskStatus , VoiceCloneProgress
1610
17-
18- from src .easevoice .inference import InferenceResult , InferenceTask , InferenceTaskData , Runner
11+ from src .easevoice .inference import InferenceTaskData , Runner
1912from src .logger import logger
20- from src .train import sovits
2113from src .train .helper import list_train_gpts , list_train_sovits
2214from src .utils .response import EaseVoiceResponse , ResponseStatus
2315
@@ -35,79 +27,50 @@ class VoiceCloneService:
3527
def __init__(self):
    # Inter-process queue. NOTE(review): after the refactor that replaced the
    # mp.Process worker with an in-process Runner, nothing visible here reads
    # self.queue — confirm other call sites still need it before removing.
    self.queue = mp.Queue()
    # The inference runner now lives in this process (it used to be spawned
    # as a separate mp.Process driven through self.queue).
    self.runner_process = Runner()
4031
def close(self):
    """Drop the in-process runner and release as much memory as possible.

    Idempotent: calling close() on an already-closed service is a no-op.
    """
    if self.runner_process is None:
        return
    # Unreference the runner, then reclaim Python objects and return any
    # cached CUDA blocks to the driver (no-op when CUDA was never used).
    self.runner_process = None
    gc.collect()
    torch.cuda.empty_cache()
4737
def get_status(self):
    """Report the service state: COMPLETED once close() has run, RUNNING otherwise."""
    runner_gone = self.runner_process is None
    return VoiceCloneStatus.COMPLETED if runner_gone else VoiceCloneStatus.RUNNING
6542
def clone(self, params: dict):
    """Run voice-clone inference and return the result as a base64 WAV.

    Args:
        params: keyword arguments for InferenceTaskData (text, model paths, ...).

    Returns:
        EaseVoiceResponse: SUCCESS with {"sampling_rate", "audio"} (audio is a
        base64-encoded WAV), or FAILED when anything goes wrong. Errors are
        logged and converted to a FAILED response instead of propagating, so
        API callers always receive an EaseVoiceResponse.
    """
    try:
        data = InferenceTaskData(**params)
        data = self.update_task_path(data)
        # inference() yields (sampling_rate, ndarray) chunks plus the seed
        # used; the seed is not surfaced to callers here.
        items, _seed = self.runner_process.inference(data)  # pyright: ignore

        # All chunks share one sampling rate; concatenate their samples into
        # a single waveform before encoding.
        sampling_rate = items[0][0]
        waveform = np.concatenate([item[1] for item in items])
        buffer = io.BytesIO()
        sf.write(buffer, waveform, sampling_rate, format="WAV")
        audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        # Covers bad params, unknown model names (ValueError from
        # update_task_path), inference failures, and empty results.
        logger.error(f"failed to clone voice for {params}, error: {e}", exc_info=True)
        return EaseVoiceResponse(ResponseStatus.FAILED, "failed to clone voice")

    return EaseVoiceResponse(ResponseStatus.SUCCESS, "Voice cloned successfully", {"sampling_rate": sampling_rate, "audio": audio})
9255
def update_task_path(self, data: "InferenceTaskData"):
    """Normalize and resolve the GPT/SoVITS model paths on *data*, in place.

    "default" is mapped to "" (meaning: use the built-in model). Any other
    non-empty value must be a model name known to the train helpers and is
    replaced with its resolved path.

    Args:
        data: the inference task data whose gpt_path/sovits_path are rewritten.

    Returns:
        The same *data* object, for call-chaining.

    Raises:
        ValueError: if a named model cannot be found.
    """
    if data.gpt_path == "default":
        data.gpt_path = ""
    if data.sovits_path == "default":
        data.sovits_path = ""

    if data.gpt_path != "":
        data.gpt_path = self._resolve_model_path(data.gpt_path, list_train_gpts(), "gpt")
    if data.sovits_path != "":
        data.sovits_path = self._resolve_model_path(data.sovits_path, list_train_sovits(), "sovits")
    return data

@staticmethod
def _resolve_model_path(name, available, kind):
    """Map model *name* to its path via the *available* mapping; log and raise on miss."""
    if name in available:
        return available[name]
    logger.error(f"failed to find {kind} model for {name}")
    raise ValueError(f"failed to find {kind} model for {name}")
0 commit comments