Skip to content

Commit e6a3686

Browse files
committed
support check available model api
1 parent cf9e760 commit e6a3686

File tree

6 files changed

+61
-14
lines changed

6 files changed

+61
-14
lines changed

src/rest/rest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from src.service.train import TrainGPTService, TrainSovitsService
2323
from src.service.voice import VoiceCloneService
2424
from src.train.gpt import GPTTrainParams
25+
from src.train.helper import list_train_gpts, list_train_sovits
2526
from src.train.sovits import SovitsTrainParams
2627
from src.utils.response import EaseVoiceResponse
2728

@@ -211,8 +212,14 @@ def _register_routes(self):
211212
self.router.post("/voiceclone/start")(self.start_service)
212213
self.router.post("/voiceclone/clone")(self.clone)
213214
self.router.get("/voiceclone/stop")(self.stop_service)
215+
self.router.get("/voiceclone/models")(self.get_available_models)
214216
self.router.get("/voiceclone/status")(self.get_status)
215217

218+
async def get_available_models(self):
219+
gpts = ["default"].extend(list_train_gpts().keys())
220+
sovits = ["default"].extend(list_train_sovits().keys())
221+
return {"gpts": gpts, "sovits": sovits}
222+
216223
async def get_status(self):
217224
if self.service is None:
218225
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail={"error": "voice clone service is not started"})

src/service/voice.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717

1818
from src.easevoice.inference import InferenceResult, InferenceTask, InferenceTaskData, Runner
1919
from src.logger import logger
20+
from src.train import sovits
21+
from src.train.helper import list_train_gpts, list_train_sovits
2022
from src.utils.response import EaseVoiceResponse, ResponseStatus
2123

24+
2225
class VoiceCloneStatus(Enum):
2326
RUNNING = "Running"
2427
COMPLETED = "Completed"
2528
ERROR = "Error"
2629

30+
2731
class VoiceCloneService:
2832
"""
2933
VoiceService is a long run service that listens for voice clone tasks and processes them.
@@ -40,7 +44,7 @@ def close(self):
4044
self.runner_process.terminate()
4145
self.runner_process.join(timeout=10)
4246
self.runner_process = None
43-
47+
4448
def get_status(self):
4549
if self.runner_process is None:
4650
return VoiceCloneStatus.COMPLETED
@@ -64,6 +68,7 @@ def clone(self, params: dict):
6468
data = InferenceTaskData(**params)
6569
queue = mp.Queue()
6670
infer_task = InferenceTask(result_queue=queue, data=data)
71+
infer_task = self.update_task_path(infer_task)
6772
self.queue.put(infer_task)
6873
result: InferenceResult = infer_task.result_queue.get(timeout=600)
6974
except Exception as e:
@@ -84,3 +89,25 @@ def clone(self, params: dict):
8489
except Exception as e:
8590
logger.error(f"failed to clone voice for {params}, error: {e}", exc_info=True)
8691
return EaseVoiceResponse(ResponseStatus.FAILED, "failed to clone voice")
92+
93+
def update_task_path(self, task: InferenceTask):
94+
if task.data.gpt_path == "default":
95+
task.data.gpt_path = ""
96+
if task.data.sovits_path == "default":
97+
task.data.sovits_path = ""
98+
99+
if task.data.gpt_path != "":
100+
gpts = list_train_gpts()
101+
if task.data.gpt_path in gpts:
102+
task.data.gpt_path = gpts[task.data.gpt_path]
103+
else:
104+
logger.error(f"failed to find gpt model for {task.data.gpt_path}")
105+
raise ValueError(f"failed to find gpt model for {task.data.gpt_path}")
106+
if task.data.sovits_path != "":
107+
sovits = list_train_sovits()
108+
if task.data.sovits_path in sovits:
109+
task.data.sovits_path = sovits[task.data.sovits_path]
110+
else:
111+
logger.error(f"failed to find sovits model for {task.data.sovits_path}")
112+
raise ValueError(f"failed to find sovits model for {task.data.sovits_path}")
113+
return task

src/train/gpt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from src.easevoice.soundstorm.auto_reg.data.data_module import Text2SemanticDataModule
2121
from src.easevoice.soundstorm.auto_reg.models.t2s_lightning_module import Text2SemanticLightningModule
22-
from src.train.helper import get_gpt_train_dir, train_logs_path, train_ckpt_path
22+
from src.train.helper import get_gpt_train_dir, train_logs_path
2323
from src.utils.config import gpt_config_path, cfg, gpt_pretrained_model_path, semantic_output, text_output_name
2424

2525

@@ -102,7 +102,7 @@ def __init__(self, params: GPTTrainParams):
102102
self.train_input_dir = params.train_input_dir
103103
self.train_output = get_gpt_train_dir(params.output_model_name)
104104
self.train_logs_output = os.path.join(self.train_output, train_logs_path)
105-
self.train_ckpts_output = os.path.join(self.train_output, train_ckpt_path)
105+
self.train_ckpts_output = os.path.join(self.train_output, "ckpt")
106106
os.makedirs(self.train_output, exist_ok=True)
107107
os.makedirs(self.train_logs_output, exist_ok=True)
108108
os.makedirs(self.train_ckpts_output, exist_ok=True)

src/train/helper.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11

22
import datetime
33
import os
4+
from pathlib import Path
45
from typing import Optional
56
from src.utils import config
67

78
train_logs_path = "logs"
8-
train_ckpt_path = "ckpt"
99

1010

1111
def generate_random_name():
@@ -25,8 +25,22 @@ def get_sovits_train_dir(name: Optional[str]):
2525

2626

2727
def list_train_gpts():
28-
return os.listdir(config.all_gpt_train_output_dir)
28+
all_gpts = Path(config.all_gpt_train_output_dir)
29+
res = {}
30+
for dir in all_gpts.iterdir():
31+
if dir.is_dir():
32+
for file in dir.iterdir():
33+
if file.is_file() and file.name.endswith(".ckpt"):
34+
res[f"{dir.name}/{file.name}"] = str(file)
35+
return res
2936

3037

3138
def list_train_sovits():
32-
return os.listdir(config.all_sovits_train_output_dir)
39+
all_sovits = Path(config.all_sovits_train_output_dir)
40+
res = {}
41+
for dir in all_sovits.iterdir():
42+
if dir.is_dir():
43+
for file in dir.iterdir():
44+
if file.is_file() and file.name.endswith(".pth"):
45+
res[f"{dir.name}/{file.name}"] = str(file)
46+
return res

src/train/sovits.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, List, Tuple
55
import torch.distributed as dist
66
import os
7-
from src.train.helper import get_sovits_train_dir, train_ckpt_path, train_logs_path
7+
from src.train.helper import get_sovits_train_dir, train_logs_path
88
from src.utils import config
99
from src.utils import helper
1010
from src.utils.helper import load_json
@@ -134,10 +134,9 @@ def _update_hparams(self, hps: TrainConfig, params: SovitsTrainParams):
134134
hps.data.exp_dir = params.train_input_dir
135135
hps.train.output_dir = get_sovits_train_dir(params.output_model_name)
136136
hps.train.train_logs_dir = os.path.join(hps.train.output_dir, train_logs_path)
137-
hps.train.save_weight_dir = os.path.join(hps.train.output_dir, train_ckpt_path)
137+
hps.train.save_weight_dir = hps.train.output_dir
138138
os.makedirs(hps.train.output_dir, exist_ok=True)
139139
os.makedirs(hps.train.train_logs_dir, exist_ok=True)
140-
os.makedirs(hps.train.save_weight_dir, exist_ok=True)
141140

142141
# set pretrained model path
143142
if params.pretrained_s2G == "":
@@ -570,7 +569,7 @@ def _train_and_evaluate(
570569
hps.train.learning_rate,
571570
epoch,
572571
os.path.join(
573-
hps.train.save_weight_dir, f"sovits_G_epoch{epoch}_step{self.step}.pth"
572+
hps.train.train_logs_dir, f"sovits_G_epoch{epoch}_step{self.step}.pth"
574573
),
575574
)
576575
ckpt.save_checkpoint(
@@ -579,7 +578,7 @@ def _train_and_evaluate(
579578
hps.train.learning_rate,
580579
epoch,
581580
os.path.join(
582-
hps.train.save_weight_dir, f"sovits_D_epoch{epoch}_step{self.step}.pth"
581+
hps.train.train_logs_dir, f"sovits_D_epoch{epoch}_step{self.step}.pth"
583582
),
584583
)
585584
else:
@@ -589,7 +588,7 @@ def _train_and_evaluate(
589588
hps.train.learning_rate,
590589
epoch,
591590
os.path.join(
592-
hps.train.save_weight_dir, "sovits_G_latest.pth"
591+
hps.train.train_logs_dir, "sovits_G_latest.pth"
593592
),
594593
)
595594
ckpt.save_checkpoint(
@@ -598,7 +597,7 @@ def _train_and_evaluate(
598597
hps.train.learning_rate,
599598
epoch,
600599
os.path.join(
601-
hps.train.save_weight_dir, "sovits_D_latest.pth"
600+
hps.train.train_logs_dir, "sovits_D_latest.pth"
602601
),
603602
)
604603
if rank == 0 and hps.train.if_save_every_weights == True:

tests/train_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
class TestTrain(unittest.TestCase):
1111
gpt_service = TrainGPTService(gpt_params=GPTTrainParams(
12-
output_path="./output",
12+
train_input_dir="./output",
1313
output_model_name="test",
1414
))
1515
sovits_service = TrainSovitsService(sovits_params=SovitsTrainParams(

0 commit comments

Comments
 (0)