Skip to content

Commit 54eb6d5

Browse files
committed
update task service, update vocie clone service
1 parent 3e1560b commit 54eb6d5

File tree

6 files changed

+117
-38
lines changed

6 files changed

+117
-38
lines changed

models/pretrained/.gitkeep

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Please download model from
1+
Please download model from https://huggingface.co/lj1995/GPT-SoVITS

src/api/api.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,19 @@ class TaskStatus:
1818
FAILED = "FAILED"
1919

2020

21-
class Progress(object):
21+
class ServiceNames:
22+
AUDIO = "audio"
23+
VOICE_CLONE = "voice_clone"
24+
25+
26+
class Progress(BaseModel):
2227
"""Progress of a task."""
2328
status: str = TaskStatus.PENDING
2429
current_step: str = ""
2530
total_steps: int = 0
2631
completed_steps: int = 0
2732
current_step_progress: int = 0
33+
message: str = ""
2834

2935

3036
class AudioTaskProgressInitial(Progress):
@@ -36,6 +42,16 @@ class AudioTaskProgressInitial(Progress):
3642
current_step_progress: int = 0
3743

3844

45+
class VoiceCloneProgress(Progress):
46+
"""Progress of a task."""
47+
status: str = TaskStatus.PENDING
48+
current_step: str = ""
49+
total_steps: int = 1
50+
completed_steps: int = 0
51+
current_step_progress: int = 0
52+
message: str = ""
53+
54+
3955
# Task models
4056
class Task(BaseModel):
4157
"""Task model."""

src/easevoice/inference/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,19 @@
2222

2323
@dataclasses.dataclass
2424
class InferenceResult:
25+
"""
26+
Result of inference
27+
"""
2528
items: list = dataclasses.field(default_factory=list)
2629
seed: int = -1
2730
error: Optional[str] = None
2831

2932

3033
@dataclasses.dataclass
3134
class InferenceTaskData:
35+
"""
36+
Data for inference
37+
"""
3238
text: str
3339
text_lang: str
3440
ref_audio_path: str
@@ -52,6 +58,9 @@ class InferenceTaskData:
5258

5359
@dataclasses.dataclass
5460
class InferenceTask:
61+
"""
62+
Task for inference
63+
"""
5564
result_queue: multiprocessing.Queue
5665
data: InferenceTaskData
5766

src/rest/rest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from src.service.task import TaskService
33
from src.service.file import FileService
44
from src.service.audio import AudioService
5+
from src.service.voice import VoiceCloneService
56
from src.api.api import (
7+
Progress,
68
Task,
79
CreateTaskResponse,
810
UpdateTaskRequest,
@@ -196,6 +198,7 @@ async def delete_files(self, request: DeleteFilesRequest):
196198
app = FastAPI()
197199

198200
task_service = TaskService()
201+
voice_service = VoiceCloneService(task_service=task_service)
199202
task_api = TaskAPI(task_service)
200203
app.include_router(task_api.router, prefix="/apis/v1")
201204

src/service/task.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid
33
import json
44
from datetime import datetime, timezone
5-
from typing import List, Optional
5+
from typing import Callable, Dict, List, Optional
66
from src.api.api import Task, Progress, TaskStatus
77

88

@@ -22,10 +22,22 @@ def __init__(self, base_dir: Optional[str] = None):
2222
self.base_dir = base_dir
2323
os.makedirs(self.base_dir, exist_ok=True)
2424

25+
tasks = self.get_tasks()
26+
self._tasks: Dict[str, Task] = {}
27+
for task in tasks:
28+
self._tasks[task.taskID] = task
29+
2530
def _task_metadata_path(self, task_id: str) -> str:
2631
"""Get the path to the task metadata file."""
2732
return os.path.join(self.base_dir, task_id, ".metadata.json")
2833

34+
def filter_tasks(self, fn: Callable[[Task], bool]) -> List[Task]:
35+
"""
36+
Filter task using a function.
37+
For example, to get all pending tasks for a service.
38+
"""
39+
return sorted(list(filter(fn, self._tasks.values())), key=lambda t: t.createdAt)
40+
2941
def create_task(self, service_name: str, args: dict) -> Task:
3042
"""Create a new task."""
3143
task_id = str(uuid.uuid4())
@@ -43,10 +55,12 @@ def create_task(self, service_name: str, args: dict) -> Task:
4355
args=args,
4456
progress=Progress(),
4557
)
58+
self._tasks[task_id] = task
4659
return task
4760

4861
def submit_task(self, task: Task):
4962
self._save_task_metadata(task)
63+
self._tasks[task.taskID] = task
5064

5165
def get_tasks(self) -> List[Task]:
5266
"""Get all tasks."""
@@ -57,6 +71,7 @@ def get_tasks(self) -> List[Task]:
5771
try:
5872
task = self._load_task_metadata(task_id)
5973
tasks.append(task)
74+
self._tasks[task_id] = task
6075
except FileNotFoundError:
6176
pass # Skip invalid tasks
6277
return tasks
@@ -66,11 +81,13 @@ def update_task(self, task_id: str, name: str) -> Task:
6681
task = self._load_task_metadata(task_id)
6782
task.name = name
6883
self._save_task_metadata(task)
84+
self._tasks[task.taskID] = task
6985
return task
7086

7187
def delete_task(self, task_id: str):
7288
"""Delete a task."""
7389
task = self._load_task_metadata(task_id)
90+
self._tasks.pop(task_id)
7491
self._delete_directory(task.homePath)
7592

7693
def _save_task_metadata(self, task: Task):

src/service/voice.py

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,91 @@
1+
from concurrent.futures import thread
12
import gc
23
import multiprocessing as mp
4+
import os
5+
import queue
36
import threading
7+
import time
8+
import numpy as np
9+
from scipy.io import wavfile
410

5-
from torch import mul
6-
11+
from src.api.api import ServiceNames, TaskStatus, VoiceCloneProgress
12+
from src.service.task import TaskService
713
from ..utils.response import EaseVoiceResponse, ResponseStatus
814

915

1016
from ..easevoice.inference import InferenceResult, InferenceTask, InferenceTaskData, Runner
1117
from ..logger import logger
18+
from src.easevoice import inference
19+
20+
21+
class VoiceCloneService:
22+
"""
23+
VoiceService is a long run service that listens for voice clone tasks and processes them.
24+
"""
1225

26+
def __init__(self, task_service: TaskService):
27+
self.task_service = task_service
1328

14-
class VoiceService:
15-
def __init__(self):
1629
self.queue = mp.Queue()
17-
self.runner_process = mp.Process(target=VoiceService.init_runner, args=(self.queue,))
30+
self.runner_process = mp.Process(target=VoiceCloneService._init_runner, args=(self.queue,))
1831
self.runner_process.start()
19-
self.locker = threading.Lock()
32+
33+
self._run_tasks = threading.Thread(target=self._run)
34+
self._run_tasks.start()
2035

2136
@staticmethod
22-
def init_runner(queue: mp.Queue):
37+
def _init_runner(queue: mp.Queue):
2338
"""
2439
Call this method to start the runner process
2540
"""
2641
runner = Runner(queue)
2742
runner.run()
2843
gc.collect()
2944

30-
def clone(self, input: dict):
31-
ok = self.locker.acquire(timeout=5)
32-
if not ok:
33-
return EaseVoiceResponse(ResponseStatus.FAILED, "There is another task running, please try again later")
34-
35-
try:
36-
data = InferenceTaskData(**input)
37-
queue = mp.Queue()
38-
task = InferenceTask(result_queue=queue, data=data)
39-
self.queue.put(task)
40-
result: InferenceResult = task.result_queue.get(timeout=600)
41-
except Exception as e:
42-
logger.error(f"failed to clone voice for {input}, error: {e}", exc_info=True)
43-
result = InferenceResult(error=str(e))
44-
45-
finally:
46-
self.locker.release()
47-
48-
if result.error:
49-
return EaseVoiceResponse(ResponseStatus.FAILED, result.error)
50-
51-
return EaseVoiceResponse(
52-
ResponseStatus.SUCCESS,
53-
"Cloned voice successfully",
54-
{
55-
"items": result.items,
56-
"seed": result.seed
57-
})
45+
def _run(self):
46+
while True:
47+
tasks = self.task_service.filter_tasks(lambda t: t.service_name == ServiceNames.VOICE_CLONE and t.progress.status == TaskStatus.PENDING)
48+
if len(tasks) == 0:
49+
logger.debug("No pending tasks found for voice clone")
50+
else:
51+
task = tasks[0]
52+
logger.info(f"Processing task {task.taskID}, args: {task.args}")
53+
54+
task.progress.status = TaskStatus.IN_PROGRESS
55+
self.task_service.submit_task(task)
56+
57+
try:
58+
data = InferenceTaskData(**task.args)
59+
queue = mp.Queue()
60+
infer_task = InferenceTask(result_queue=queue, data=data)
61+
self.queue.put(infer_task)
62+
result: InferenceResult = infer_task.result_queue.get(timeout=600)
63+
except Exception as e:
64+
logger.error(f"failed to clone voice for {task.args}, error: {e}", exc_info=True)
65+
result = InferenceResult(error=str(e))
66+
67+
if result.error:
68+
progress = VoiceCloneProgress(status=TaskStatus.FAILED, current_step="Failed", total_steps=1, completed_steps=1, current_step_progress=100, message=result.error)
69+
task.progress = progress
70+
self.task_service.submit_task(task)
71+
logger.error(f"failed to clone voice for {task.args}, error: {result.error}")
72+
else:
73+
try:
74+
sampling_rate = result.items[0][0]
75+
audio = np.concatenate([item[1] for item in result.items])
76+
output_file = os.path.join(task.homePath, "output.wav")
77+
wavfile.write(output_file, sampling_rate, audio)
78+
79+
progress = VoiceCloneProgress(status=TaskStatus.COMPLETED, current_step="Completed", total_steps=1, completed_steps=1, current_step_progress=100)
80+
task.progress = progress
81+
self.task_service.submit_task(task)
82+
logger.info(f"Successfully cloned voice for {task.args}")
83+
84+
except Exception as e:
85+
logger.error(f"failed to clone voice for {task.args}, error: {e}", exc_info=True)
86+
progress = VoiceCloneProgress(status=TaskStatus.FAILED, current_step="Failed", total_steps=1, completed_steps=1, current_step_progress=100, message=str(e))
87+
task.progress = progress
88+
self.task_service.submit_task(task)
89+
logger.error(f"failed to clone voice for {task.args}, error: {e}")
90+
91+
time.sleep(1)

0 commit comments

Comments
 (0)