Skip to content

Commit dde2386

Browse files
committed
fix list not exist dir for gpt and sovits train
1 parent fba6d1b commit dde2386

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

src/rest/rest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ async def train_gpt(self, params: GPTTrainParams):
397397
# Note: it could not be empty, because the training processing has multiple processes, each process could generate a model name
398398
if params.output_model_name == "":
399399
params.output_model_name = "gpt_" + generate_random_name()
400-
model_path = get_gpt_train_dir(params.output_model_name)
400+
model_path = get_gpt_train_dir(params.project_dir, params.output_model_name)
401401

402402
backtask_with_session_guard(uid, TaskType.train_gpt, asdict(params), start_task_with_subprocess, uid=uid, request=params, cmd_file=TaskCMD.train_gpt)
403403
return EaseVoiceResponse(ResponseStatus.SUCCESS, "GPT training started", uuid=str(uid), data={"model_path": model_path})
@@ -408,7 +408,7 @@ async def train_sovits(self, params: SovitsTrainParams):
408408
# Note: it could not be empty, because the training processing has multiple processes, each process could generate a model name
409409
if params.output_model_name == "":
410410
params.output_model_name = "sovits_" + generate_random_name()
411-
model_path = get_sovits_train_dir(params.output_model_name)
411+
model_path = get_sovits_train_dir(params.project_dir, params.output_model_name)
412412
uid = str(uuid.uuid4())
413413
backtask_with_session_guard(uid, TaskType.train_sovits, asdict(params), start_task_with_subprocess, uid=uid, request=params, cmd_file=TaskCMD.tran_sovits)
414414
return EaseVoiceResponse(ResponseStatus.SUCCESS, "Sovits training started", uuid=uid, data={"model_path": model_path})

src/train/helper.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from pathlib import Path
66
from typing import Optional
77

8+
from src.logger import logger
9+
810
train_logs_path = "logs"
911

1012

@@ -33,25 +35,33 @@ def get_sovits_train_dir(project_dir: str, name: Optional[str]):
3335

3436

3537
def list_train_gpts(project_dir: str):
36-
all_gpts = Path(_get_all_gpt_train_output_dir(project_dir))
37-
res = {}
38-
for dir in all_gpts.iterdir():
39-
if dir.is_dir():
40-
for file in dir.iterdir():
41-
if file.is_file() and file.name.endswith(".ckpt"):
42-
res[f"{dir.name}/{file.name}"] = str(file)
43-
return res
38+
try:
39+
all_gpts = Path(_get_all_gpt_train_output_dir(project_dir))
40+
res = {}
41+
for dir in all_gpts.iterdir():
42+
if dir.is_dir():
43+
for file in dir.iterdir():
44+
if file.is_file() and file.name.endswith(".ckpt"):
45+
res[f"{dir.name}/{file.name}"] = str(file)
46+
return res
47+
except Exception as e:
48+
logger.warning(f"list_train_gpts failed: {e}")
49+
return {}
4450

4551

4652
def list_train_sovits(project_dir: str):
47-
all_sovits = Path(_get_all_sovits_train_output_dir(project_dir))
48-
res = {}
49-
for dir in all_sovits.iterdir():
50-
if dir.is_dir():
51-
for file in dir.iterdir():
52-
if file.is_file() and file.name.endswith(".pth"):
53-
res[f"{dir.name}/{file.name}"] = str(file)
54-
return res
53+
try:
54+
all_sovits = Path(_get_all_sovits_train_output_dir(project_dir))
55+
res = {}
56+
for dir in all_sovits.iterdir():
57+
if dir.is_dir():
58+
for file in dir.iterdir():
59+
if file.is_file() and file.name.endswith(".pth"):
60+
res[f"{dir.name}/{file.name}"] = str(file)
61+
return res
62+
except Exception as e:
63+
logger.warning(f"list_train_sovits failed: {e}")
64+
return {}
5565

5666

5767
@dataclass

0 commit comments

Comments
 (0)