Skip to content

Commit d08380c

Browse files
committed
refactor voice clone api
1 parent 5d211b2 commit d08380c

File tree

4 files changed

+51
-126
lines changed

4 files changed

+51
-126
lines changed

src/easevoice/inference/__init__.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,6 @@
2020
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
2121

2222

23-
@dataclasses.dataclass
24-
class InferenceResult:
25-
"""
26-
Result of inference
27-
"""
28-
items: list = dataclasses.field(default_factory=list)
29-
seed: int = -1
30-
error: Optional[str] = None
31-
32-
3323
@dataclasses.dataclass
3424
class InferenceTaskData:
3525
"""
@@ -42,29 +32,20 @@ class InferenceTaskData:
4232
prompt_lang: str
4333
text_split_method: str
4434
aux_ref_audio_paths: list = dataclasses.field(default_factory=list)
45-
seed = -1
46-
top_k = 5
47-
top_p = 1
48-
temperature = 1
49-
batch_size = 20
50-
speed_factor = 1.0
51-
ref_text_free = False
52-
split_bucket = True
53-
fragment_interval = 0.3
54-
keep_random = True
55-
parallel_infer = True
56-
repetition_penalty = 1.3
57-
sovits_path = ""
58-
gpt_path = ""
59-
60-
61-
@dataclasses.dataclass
62-
class InferenceTask:
63-
"""
64-
Task for inference
65-
"""
66-
result_queue: multiprocessing.Queue
67-
data: InferenceTaskData
35+
seed: int = -1
36+
top_k: int = 5
37+
top_p: int = 1
38+
temperature: float = 1.0
39+
batch_size: int = 20
40+
speed_factor: float = 1.0
41+
ref_text_free: bool = False
42+
split_bucket: bool = True
43+
fragment_interval: float = 0.3
44+
keep_random: bool = True
45+
parallel_infer: bool = True
46+
repetition_penalty: float = 1.3
47+
sovits_path: str = ""
48+
gpt_path: str = ""
6849

6950

7051
class Runner:
@@ -75,42 +56,23 @@ class Runner:
7556
Wait InferenceResult from the queue
7657
"""
7758

78-
def __init__(self, queue: multiprocessing.Queue):
59+
def __init__(self):
7960
tts_config = TTSConfig(os.path.join(get_base_path(), "configs", "tts_infer.yaml"))
8061
logger.info(f"tts config: {tts_config}")
8162

8263
self.tts_config = tts_config
8364
self.tts_pipeline = TTS(tts_config)
84-
self.task_queue = queue
85-
self.done = False
86-
87-
def run(self):
88-
while not self.done:
89-
task: Union[InferenceTask, int] = self.task_queue.get()
90-
if isinstance(task, int):
91-
logger.info("Received stop signal")
92-
return
93-
else:
94-
try:
95-
items, seed = self._inference(task)
96-
task.result_queue.put(
97-
InferenceResult(items=items, seed=seed)
98-
)
99-
except Exception as e:
100-
logger.error(f"error: {e}")
101-
task.result_queue.put(InferenceResult(error=str(e)))
10265

103-
def _inference(self, task: InferenceTask):
66+
def inference(self, data: InferenceTaskData):
10467
# change weight based on task
10568
try:
106-
self.tts_pipeline.update_weights(task.data.sovits_path, task.data.gpt_path)
69+
self.tts_pipeline.update_weights(data.sovits_path, data.gpt_path)
10770
except Exception as e:
10871
logger.error(f"failed to update weights: {e}")
10972
# change back to default weights
11073
self.tts_pipeline.update_weights("", "")
11174
raise e
11275

113-
data = task.data
11476
seed = -1 if data.keep_random else data.seed
11577
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
11678
inputs = {

src/easevoice/inference/tts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def run(self, inputs: dict):
678678

679679
if ref_audio_path in [None, ""] and \
680680
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
681-
raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
681+
raise ValueError("ref_audio_path cannot be empty")
682682

683683
###### setting reference audio and prompt text preprocessing ########
684684
t0 = ttime()

src/rest/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ async def clone(self, request: dict):
254254
try:
255255
return self.service.clone(request)
256256
except Exception as e:
257-
logger.error(f"failed to clone voice for {request}, err: {e}")
257+
logger.error(f"failed to clone voice for {request}, err: {e}", exc_info=True)
258258
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail={"error": f"failed to clone voice: {e}"})
259259

260260
async def stop_service(self):

src/service/voice.py

Lines changed: 32 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
11
import base64
2-
from concurrent.futures import thread
32
from enum import Enum
43
import gc
54
import io
65
import multiprocessing as mp
7-
import os
8-
import queue
9-
import threading
10-
import time
116
import numpy as np
12-
from scipy.io import wavfile
137
import soundfile as sf
8+
import torch
149

15-
from src.api.api import ServiceNames, TaskStatus, VoiceCloneProgress
1610

17-
18-
from src.easevoice.inference import InferenceResult, InferenceTask, InferenceTaskData, Runner
11+
from src.easevoice.inference import InferenceTaskData, Runner
1912
from src.logger import logger
20-
from src.train import sovits
2113
from src.train.helper import list_train_gpts, list_train_sovits
2214
from src.utils.response import EaseVoiceResponse, ResponseStatus
2315

@@ -35,79 +27,50 @@ class VoiceCloneService:
3527

3628
def __init__(self):
3729
self.queue = mp.Queue()
38-
self.runner_process = mp.Process(target=VoiceCloneService._init_runner, args=(self.queue,))
39-
self.runner_process.start()
30+
self.runner_process = Runner()
4031

4132
def close(self):
4233
if self.runner_process is not None:
43-
self.queue.put(1)
44-
self.runner_process.terminate()
45-
self.runner_process.join(timeout=10)
4634
self.runner_process = None
35+
gc.collect()
36+
torch.cuda.empty_cache()
4737

4838
def get_status(self):
4939
if self.runner_process is None:
5040
return VoiceCloneStatus.COMPLETED
51-
elif self.runner_process.is_alive():
52-
return VoiceCloneStatus.RUNNING
53-
else:
54-
return VoiceCloneStatus.ERROR
55-
56-
@staticmethod
57-
def _init_runner(queue: mp.Queue):
58-
"""
59-
Call this method to start the runner process
60-
"""
61-
runner = Runner(queue)
62-
runner.run()
63-
print("Voice clone runner process exited")
64-
gc.collect()
41+
return VoiceCloneStatus.RUNNING
6542

6643
def clone(self, params: dict):
67-
try:
68-
data = InferenceTaskData(**params)
69-
queue = mp.Queue()
70-
infer_task = InferenceTask(result_queue=queue, data=data)
71-
infer_task = self.update_task_path(infer_task)
72-
self.queue.put(infer_task)
73-
result: InferenceResult = infer_task.result_queue.get(timeout=600)
74-
except Exception as e:
75-
logger.error(f"failed to clone voice for {params}, error: {e}", exc_info=True)
76-
result = InferenceResult(error=str(e))
44+
data = InferenceTaskData(**params)
45+
data = self.update_task_path(data)
46+
items, seed = self.runner_process.inference(data) # pyright: ignore
7747

78-
if result.error:
79-
logger.error(f"failed to clone voice for {params}, error: {result.error}")
80-
return EaseVoiceResponse(ResponseStatus.FAILED, result.error)
81-
else:
82-
try:
83-
sampling_rate = result.items[0][0]
84-
data = np.concatenate([item[1] for item in result.items])
85-
buffer = io.BytesIO()
86-
sf.write(buffer, data, sampling_rate, format="WAV")
87-
audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
88-
return EaseVoiceResponse(ResponseStatus.SUCCESS, "Voice cloned successfully", {"sampling_rate": sampling_rate, "audio": audio})
89-
except Exception as e:
90-
logger.error(f"failed to clone voice for {params}, error: {e}", exc_info=True)
91-
return EaseVoiceResponse(ResponseStatus.FAILED, "failed to clone voice")
48+
sampling_rate = items[0][0]
49+
data = np.concatenate([item[1] for item in items])
50+
buffer = io.BytesIO()
51+
sf.write(buffer, data, sampling_rate, format="WAV")
52+
audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
53+
54+
return EaseVoiceResponse(ResponseStatus.SUCCESS, "Voice cloned successfully", {"sampling_rate": sampling_rate, "audio": audio})
9255

93-
def update_task_path(self, task: InferenceTask):
94-
if task.data.gpt_path == "default":
95-
task.data.gpt_path = ""
96-
if task.data.sovits_path == "default":
97-
task.data.sovits_path = ""
56+
def update_task_path(self, data: InferenceTaskData):
57+
if data.gpt_path == "default":
58+
data.gpt_path = ""
59+
if data.sovits_path == "default":
60+
data.sovits_path = ""
9861

99-
if task.data.gpt_path != "":
62+
if data.gpt_path != "":
10063
gpts = list_train_gpts()
101-
if task.data.gpt_path in gpts:
102-
task.data.gpt_path = gpts[task.data.gpt_path]
64+
if data.gpt_path in gpts:
65+
data.gpt_path = gpts[data.gpt_path]
10366
else:
104-
logger.error(f"failed to find gpt model for {task.data.gpt_path}")
105-
raise ValueError(f"failed to find gpt model for {task.data.gpt_path}")
106-
if task.data.sovits_path != "":
67+
logger.error(f"failed to find gpt model for {data.gpt_path}")
68+
raise ValueError(f"failed to find gpt model for {data.gpt_path}")
69+
if data.sovits_path != "":
10770
sovits = list_train_sovits()
108-
if task.data.sovits_path in sovits:
109-
task.data.sovits_path = sovits[task.data.sovits_path]
71+
if data.sovits_path in sovits:
72+
data.sovits_path = sovits[data.sovits_path]
11073
else:
111-
logger.error(f"failed to find sovits model for {task.data.sovits_path}")
112-
raise ValueError(f"failed to find sovits model for {task.data.sovits_path}")
113-
return task
74+
logger.error(f"failed to find sovits model for {data.sovits_path}")
75+
raise ValueError(f"failed to find sovits model for {data.sovits_path}")
76+
return data

0 commit comments

Comments
 (0)