Skip to content

Commit cf569eb

Browse files
committed
dynamic update inference weights based on request
1 parent 80990b2 commit cf569eb

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

src/easevoice/inference/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class InferenceTaskData:
5454
keep_random = True
5555
parallel_infer = True
5656
repetition_penalty = 1.3
57+
sovits_path = ""
58+
gpt_path = ""
5759

5860

5961
@dataclasses.dataclass
@@ -99,6 +101,15 @@ def run(self):
99101
task.result_queue.put(InferenceResult(error=str(e)))
100102

101103
def _inference(self, task: InferenceTask):
104+
# change weight based on task
105+
try:
106+
self.tts_pipeline.update_weights(task.data.sovits_path, task.data.gpt_path)
107+
except Exception as e:
108+
logger.error(f"failed to update weights: {e}")
109+
# change back to default weights
110+
self.tts_pipeline.update_weights("", "")
111+
raise e
112+
102113
data = task.data
103114
seed = -1 if data.keep_random else data.seed
104115
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)

src/easevoice/inference/tts.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ def __init__(self, configs: Union[dict, str, TTSConfig]):
194194
self.bert_model: AutoModelForMaskedLM = None # pyright: ignore
195195
self.cnhuhbert_model: CNHubert = None # pyright: ignore
196196

197+
self.current_sovits_path = configs.vits_weights_path # pyright: ignore
198+
self.current_gpt_path = configs.t2s_weights_path # pyright: ignore
199+
197200
self._init_models()
198201

199202
self.text_preprocessor: TextPreprocessor = \
@@ -216,6 +219,27 @@ def __init__(self, configs: Union[dict, str, TTSConfig]):
216219
self.stop_flag: bool = False
217220
self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
218221

222+
def update_weights(self, sovits_path: str, gpt_path: str):
223+
if sovits_path != self.current_sovits_path:
224+
if sovits_path == "":
225+
# empty sovits path, use default path
226+
if self.current_sovits_path != self.configs.vits_weights_path:
227+
self.current_sovits_path = self.configs.vits_weights_path
228+
self.init_vits_weights(self.current_sovits_path)
229+
else:
230+
self.current_sovits_path = sovits_path
231+
self.init_vits_weights(self.current_sovits_path)
232+
233+
if gpt_path != self.current_gpt_path:
234+
if gpt_path == "":
235+
# empty gpt path, use default path
236+
if self.current_gpt_path != self.configs.t2s_weights_path:
237+
self.current_gpt_path = self.configs.t2s_weights_path
238+
self.init_t2s_weights(self.current_gpt_path)
239+
else:
240+
self.current_gpt_path = gpt_path
241+
self.init_t2s_weights(self.current_gpt_path)
242+
219243
def _init_models(self,):
220244
self.init_t2s_weights(self.configs.t2s_weights_path)
221245
self.init_vits_weights(self.configs.vits_weights_path)

tests/easevoice/tts_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from src.easevoice.inference import Runner
77
from src.logger import logger
8-
from src.service.task import TaskService
98
from src.service.voice import VoiceCloneService
109

1110

@@ -19,8 +18,7 @@ def test_tts_runner(self):
1918
queue.put(1)
2019

2120
def test_voice_clone_service(self):
22-
task_service = TaskService()
23-
voice_service = VoiceCloneService(task_service)
21+
voice_service = VoiceCloneService()
2422
time.sleep(5)
2523
voice_service.close()
2624

0 commit comments

Comments
 (0)