11
22import dataclasses
3+
4+ from ...utils .config .config import GlobalCFG
5+ from ...utils .path import get_base_path
36from .preprocessor import TextPreprocessor
47from .segmentation import SPLITS
5- from module .mel_processing import spectrogram_torch
8+ from .. module .mel_processing import spectrogram_torch
69from ...utils .audio import load_audio
710from time import time as ttime
811import librosa
9- from module .models import SynthesizerTrn
12+ from .. module .models import SynthesizerTrn
1013from ..feature_extractor .cnhubert import CNHubert
1114from ..soundstorm .auto_reg .models .t2s_lightning_module import Text2SemanticLightningModule
1215from transformers import AutoModelForMaskedLM , AutoTokenizer
@@ -54,16 +57,21 @@ def set_seed(seed: int):
5457 return seed
5558
5659
60+ def _get_default_configs ():
61+ global_config = GlobalCFG ()
62+ return {
63+ "device" : global_config .device ,
64+ "is_half" : global_config .is_half ,
65+ "t2s_weights_path" : global_config .gpt_path ,
66+ "vits_weights_path" : global_config .sovits_path ,
67+ "cnhuhbert_base_path" : global_config .cnhubert_path ,
68+ "bert_base_path" : global_config .bert_path ,
69+ }
70+
71+
5772class TTSConfig :
5873 default_configs = {
59- "default" : {
60- "device" : "cpu" ,
61- "is_half" : False ,
62- "t2s_weights_path" : "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" ,
63- "vits_weights_path" : "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" ,
64- "cnhuhbert_base_path" : "GPT_SoVITS/pretrained_models/chinese-hubert-base" ,
65- "bert_base_path" : "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" ,
66- },
74+ "default" : _get_default_configs (),
6775 }
6876 languages : list = ["auto" , "auto_yue" , "en" , "zh" , "ja" , "yue" , "ko" , "all_zh" , "all_ja" , "all_yue" , "all_ko" ]
6977 # "all_zh",#全部按中文识别
@@ -79,6 +87,8 @@ class TTSConfig:
7987 # "auto_yue",#多语种启动切分识别语种
8088
8189 def __init__ (self , configs : Union [dict , str , None ] = None ): # pyright: ignore
90+ global_config = GlobalCFG ()
91+
8292 configs_base_path = os .path .join (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))), "configs" )
8393 os .makedirs (configs_base_path , exist_ok = True )
8494 self .configs_path : str = os .path .join (configs_base_path , "tts_infer.yaml" )
@@ -97,8 +107,8 @@ def __init__(self, configs: Union[dict, str, None] = None): # pyright: ignore
97107 self .default_configs ["default" ] = configs .get ("default" , self .default_configs ["default" ])
98108
99109 self .configs : dict = configs .get ("custom" , deepcopy (self .default_configs ["default" ]))
100- self .device = self .configs .get ("device" , torch .device ( "cpu" ) )
101- self .is_half = self .configs .get ("is_half" , False )
110+ self .device = self .configs .get ("device" , global_config .device )
111+ self .is_half = self .configs .get ("is_half" , global_config . is_half )
102112
103113 def get_path (key : str ):
104114 path = self .configs .get (key , None )
@@ -180,7 +190,7 @@ def __init__(self, configs: Union[dict, str, TTSConfig]):
180190
181191 self .t2s_model : Text2SemanticLightningModule = None # pyright: ignore
182192 self .vits_model : SynthesizerTrn = None # pyright: ignore
183- self .bert_tokenizer : = None # pyright: ignore
193+ self .bert_tokenizer : AutoTokenizer = None # pyright: ignore
184194 self .bert_model : AutoModelForMaskedLM = None # pyright: ignore
185195 self .cnhuhbert_model : CNHubert = None # pyright: ignore
186196
@@ -223,7 +233,7 @@ def init_cnhuhbert_weights(self, base_path: str):
223233
224234 def init_bert_weights (self , base_path : str ):
225235 logger .info (f"Loading BERT weights from { base_path } " )
226- self .bert_tokenizer = AutoTokenizer .from_pretrained (base_path )
236+ self .bert_tokenizer = AutoTokenizer .from_pretrained (base_path ) # pyright: ignore
227237 self .bert_model = AutoModelForMaskedLM .from_pretrained (base_path )
228238 self .bert_model = self .bert_model .eval () # pyright: ignore
229239 self .bert_model = self .bert_model .to (self .configs .device ) # pyright: ignore
0 commit comments