Skip to content

Commit 39d857e

Browse files
committed
feat: SparkTTS support half mode
- SparkTTS 支持半精度,但是只允许 bf16 形式运行 - 增加 bf16 启动项
1 parent 1847b26 commit 39d857e

File tree

4 files changed

+44
-8
lines changed

4 files changed

+44
-8
lines changed

modules/core/models/tts/SparkTTS/SparkTTS.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy.typing as npt
2121
import torch
22-
from transformers import AutoModelForCausalLM, AutoTokenizer
22+
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2ForCausalLM
2323

2424
from modules.repos_static.spark_tts.sparktts.models.audio_tokenizer import (
2525
BiCodecTokenizer,
@@ -37,7 +37,12 @@ class SparkTTS:
3737
Spark-TTS for text-to-speech generation.
3838
"""
3939

40-
def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
40+
def __init__(
41+
self,
42+
model_dir: Path,
43+
device: torch.device = torch.device("cuda:0"),
44+
dtype: torch.dtype = torch.float32,
45+
):
4146
"""
4247
Initializes the SparkTTS model with the provided configurations and device.
4348
@@ -46,6 +51,7 @@ def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0"
4651
device (torch.device): The device (CPU/GPU) to run the model on.
4752
"""
4853
self.device = device
54+
self.dtype = dtype
4955
self.model_dir = model_dir
5056
self.configs = load_config(f"{model_dir}/config.yaml")
5157
self.sample_rate = self.configs["sample_rate"]
@@ -54,9 +60,12 @@ def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0"
5460
def _initialize_inference(self):
5561
"""Initializes the tokenizer, model, and audio tokenizer for inference."""
5662
self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
57-
self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
63+
self.model: Qwen2ForCausalLM = AutoModelForCausalLM.from_pretrained(
64+
f"{self.model_dir}/LLM"
65+
)
5866
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
59-
self.model.to(self.device)
67+
self.model.to(device=self.device, dtype=self.dtype)
68+
self.model.eval()
6069

6170
def process_prompt(
6271
self,

modules/core/models/tts/SparkTTS/SparkTTSModel.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Literal, Optional, Union
66

77
import soundfile as sf
8+
import torch
89

910
from modules.core.models.tts.SparkTTS.SparkTTS import SparkTTS
1011
from modules.core.models.TTSModel import TTSModel
@@ -38,11 +39,25 @@ def check_files(self) -> None:
3839
def is_downloaded(self) -> bool:
3940
return self.model_path.exists()
4041

42+
def get_dtype(self):
43+
dtype = super().get_dtype()
44+
if dtype == torch.float16:
45+
# NOTE: SparkTTS 模型对于 float16 精度很糟糕,几乎破坏了模型,无法运行
46+
# NOTE: 你可以使用 `--bf16` 启动项开启 bfloat16 模式,虽然可以运行,但是还是容易生成大量空白
47+
# NOTE: 所以,如果没有使用 bf16 又开启了 half ,那么将切换为 f32
48+
logger.warning(
49+
"检测到 dtype 为 float16,但 SparkTTS 对 float16 支持很差,已强制切换为 float32。"
50+
"建议使用 --bf16 开启 bfloat16 模式以获得更好兼容性。"
51+
)
52+
return torch.float32
53+
return dtype
54+
4155
def load(self):
4256
if self.model is None:
43-
# TODO: 配置 dtype
4457
self.model = SparkTTS(
45-
model_dir=str(self.model_path), device=self.get_device()
58+
model_dir=str(self.model_path),
59+
device=self.get_device(),
60+
dtype=self.get_dtype(),
4661
)
4762
return self.model
4863

modules/devices/devices.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,12 @@ def reset_device():
145145
config.runtime_env_vars.no_half = True
146146

147147
if not config.runtime_env_vars.no_half:
148-
dtype = torch.float16
149-
logger.info("Using half precision: torch.float16")
148+
if config.runtime_env_vars.bf16:
149+
dtype = torch.bfloat16
150+
logger.info("Using half precision: torch.bfloat16")
151+
else:
152+
dtype = torch.float16
153+
logger.info("Using half precision: torch.float16")
150154
else:
151155
dtype = torch.float32
152156
logger.info("Using full precision: torch.float32")

modules/models_setup.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,18 @@ def setup_model_args(parser: argparse.ArgumentParser):
5656
action="store_true",
5757
help="Preload all models at startup",
5858
)
59+
# NOTE: 开启 ftc 等于给 torch 预热,但是服务冷启动变慢
5960
parser.add_argument(
6061
"--ftc",
6162
action="store_true",
6263
help="Enable first time calculation",
6364
)
65+
# NOTE: 不同模型可能有不同的适配度,比如 sparktts 只能使用 bfloat16 而不能使用 float16 ,所以某些模型半精度的情况需要开启这个
66+
parser.add_argument(
67+
"--bf16",
68+
action="store_true",
69+
help="Use bfloat16 as the data type when loading with half precision.",
70+
)
6471

6572

6673
def process_model_args(args: argparse.Namespace):
@@ -75,6 +82,7 @@ def process_model_args(args: argparse.Namespace):
7582
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
7683
preload_models = env.get_and_update_env(args, "preload_models", False, bool)
7784
enable_ftc = env.get_and_update_env(args, "ftc", False, bool)
85+
bf16 = env.get_and_update_env(args, "bf16", False, bool)
7886

7987
# TODO: 需要等 zoo 模块实现
8088
# generate_audio.setup_lru_cache()

0 commit comments

Comments
 (0)