Skip to content

Commit 4cd5ba8

Browse files
committed
fix: f5tts rename #267
1 parent 5d936f1 commit 4cd5ba8

File tree

4 files changed

+32
-19
lines changed

4 files changed

+32
-19
lines changed

modules/core/models/tts/F5/F5ttsApi.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929

3030
# NOTE: 目前不支持 bigvgan 因为需要引入外部库,并且似乎没什么特别区别
3131
class F5TTS:
32+
3233
def __init__(
3334
self,
35+
ckpt_file: Path,
3436
model: Literal["F5TTS_v1_Base", "F5TTS_Base"] = "F5TTS_v1_Base",
3537
ode_method="euler",
3638
use_ema=True,
@@ -47,13 +49,6 @@ def __init__(
4749
/ "configs"
4850
/ f"{model}.yaml"
4951
)
50-
ckpt_file = (
51-
Path(os.getcwd())
52-
/ "models"
53-
/ "F5-TTS"
54-
/ model
55-
/ "model_1200000.safetensors"
56-
)
5752
vocab_file = (
5853
Path(os.getcwd())
5954
/ "modules"
@@ -64,16 +59,6 @@ def __init__(
6459
/ "Emilia_ZH_EN_pinyin"
6560
/ f"vocab.txt"
6661
)
67-
if not ckpt_file.exists():
68-
# NOTE: 不存在,提示使用脚本下载 `python -m scripts.downloader.f5_tts_v1 --source huggingface`
69-
if model == "F5TTS_Base":
70-
raise ValueError(
71-
f"F5TTS model {model} not found, please download it manually using `python -m scripts.downloader.f5_tts --source huggingface`"
72-
)
73-
else:
74-
raise ValueError(
75-
f"F5TTS model {model} not found, please download it manually using `python -m scripts.downloader.f5_tts_v1 --source huggingface`"
76-
)
7762

7863
# Path to str path, 因为f5的函数不支持Path
7964
ckpt_file = str(ckpt_file)

modules/core/models/tts/F5TtsModel.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,16 @@ def __init__(self) -> None:
2323
super().__init__("f5-tts")
2424

2525
v0_6_model_path = Path("./models/F5-TTS/F5TTS_Base/model_1200000.safetensors")
26-
v1_model_path = Path("./models/F5-TTS/F5TTS_v1_Base/model_1200000.safetensors")
26+
v1_120_model_path = Path(
27+
"./models/F5-TTS/F5TTS_v1_Base/model_1200000.safetensors"
28+
)
29+
# NOTE: F5TTS 在 2025三月份发布的新 v1_base 模型
30+
v1_125_model_path = Path(
31+
"./models/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"
32+
)
33+
v1_model_path = (
34+
v1_125_model_path if v1_125_model_path.exists() else v1_120_model_path
35+
)
2736

2837
self.model_path = v1_model_path
2938
self.model_version = "F5TTS_v1_Base"
@@ -55,12 +64,29 @@ def check_files(self) -> None:
5564
def is_downloaded(self) -> bool:
5665
return self.model_path.exists() and self.vocos_path.exists()
5766

67+
def _alert_ckpt_files(self):
68+
"""
69+
提醒下载模型
70+
"""
71+
if not self.model_path.exists():
72+
# NOTE: 不存在,提示使用脚本下载 `python -m scripts.downloader.f5_tts_v1 --source huggingface`
73+
if self.model_version == "F5TTS_Base":
74+
raise ValueError(
75+
f"F5TTS model {self.model_version} not found, please download it manually using `python -m scripts.downloader.f5_tts --source huggingface`"
76+
)
77+
else:
78+
raise ValueError(
79+
f"F5TTS model {self.model_version} not found, please download it manually using `python -m scripts.downloader.f5_tts_v1 --source huggingface`"
80+
)
81+
5882
def load(self) -> F5TTS:
83+
self._alert_ckpt_files()
5984
self.check_files()
6085

6186
with self.load_lock:
6287
if self.model is None:
6388
self.model = F5TTS(
89+
ckpt_file=self.model_path,
6490
model=self.model_version,
6591
vocoder_local_path=str(self.vocos_path),
6692
device=self.get_device(),

scripts/downloader/f5_tts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(self):
3131
modelscope_repo="AI-ModelScope/F5-TTS",
3232
huggingface_repo="SWivid/F5-TTS",
3333
required_files=required_files,
34+
just_download_required_files=True,
3435
)
3536

3637
self.logger = logger

scripts/downloader/f5_tts_v1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(self):
3131
modelscope_repo="AI-ModelScope/F5-TTS",
3232
huggingface_repo="SWivid/F5-TTS",
3333
required_files=required_files,
34+
just_download_required_files=True,
3435
)
3536

3637
self.logger = logger
@@ -60,7 +61,7 @@ def from_modelscope(self):
6061

6162
# v1 版本是 2025/3 发布的新模型
6263
url = "https://modelscope.cn/models/AI-ModelScope/F5-TTS/resolve/master/F5TTS_v1_Base/model_1250000.safetensors"
63-
dest_path = self.model_dir / "F5TTS_v1_Base" / "model_1200000.safetensors"
64+
dest_path = self.model_dir / "F5TTS_v1_Base" / "model_1250000.safetensors"
6465

6566
self._download_with_progress(url, dest_path)
6667

0 commit comments

Comments
 (0)