train.py
import os
import sys
import torch
from transformers import Trainer, TrainingArguments
from safetensors.torch import save_file
from src.config import TrainConfig
from src.dataset import ChatterboxDataset, data_collator_turbo, data_collator_standart
from src.model import resize_and_load_t3_weights, ChatterboxTrainerWrapper
from src.preprocess_ljspeech import preprocess_dataset_ljspeech
from src.preprocess_file_based import preprocess_dataset_file_based
from src.preprocess_json import preprocess_dataset_json_based
from src.utils import setup_logger, check_pretrained_models
from src.inference_callback import InferenceCallback
from src.chatterbox_.tts import ChatterboxTTS
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.models.t3.t3 import T3
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = setup_logger("ChatterboxFinetune")
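
# Fine-tunes Chatterbox's T3 model (text tokens -> speech tokens) on a custom
# dataset, resizing its text vocabulary first. S3Gen and the VoiceEncoder stay
# frozen; only T3 is trained and saved.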
def main():
    cfg = TrainConfig()

    logger.info("--- Starting Chatterbox Finetuning ---")
    logger.info(f"Mode: {'CHATTERBOX-TURBO' if cfg.is_turbo else 'CHATTERBOX-TTS'}")

    # 0. CHECK MODEL FILES
    mode_check = "chatterbox_turbo" if cfg.is_turbo else "chatterbox"
    if not check_pretrained_models(mode=mode_check):
        sys.exit(1)
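
    # `device` is only used for logging; the HF Trainer moves the model to the
    # GPU itself, and both engines are deliberately loaded on CPU first.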
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. SELECT THE CORRECT ENGINE CLASS
    if cfg.is_turbo:
        EngineClass = ChatterboxTurboTTS
    else:
        EngineClass = ChatterboxTTS

    logger.info(f"Device: {device}")
    logger.info(f"Model Directory: {cfg.model_dir}")

    # 2. LOAD ORIGINAL MODEL TEMPORARILY
    logger.info("Loading original model to extract weights...")
    # Loading on CPU first to save VRAM
    tts_engine_original = EngineClass.from_local(cfg.model_dir, device="cpu")
    pretrained_t3_state_dict = tts_engine_original.t3.state_dict()
    original_t3_config = tts_engine_original.t3.hp
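
    # Keep the pretrained weights and hyperparameters so they can be grafted
    # onto a fresh T3 instance built with the enlarged text vocabulary.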
    # 3. CREATE NEW T3 MODEL WITH NEW VOCAB SIZE
    logger.info(f"Creating new T3 model with vocab size: {cfg.new_vocab_size}")
    new_t3_config = original_t3_config
    new_t3_config.text_tokens_dict_size = cfg.new_vocab_size
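    # NOTE: `new_t3_config` aliases the original hp object rather than copying
    # it; mutating it is safe here only because the original engine is deleted
    # just below.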
    # Disable KV caching during training; it only helps at generation time.
    new_t3_config.use_cache = False

    new_t3_model = T3(hp=new_t3_config)
    # 4. TRANSFER WEIGHTS
    logger.info("Transferring weights...")
    new_t3_model = resize_and_load_t3_weights(new_t3_model, pretrained_t3_state_dict)
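    # This helper (in src/model.py) is expected to copy every shape-matching
    # tensor and resize the text-token embedding/head to the new vocab size.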

    # --- SPECIAL SETTING FOR TURBO ---
    if cfg.is_turbo:
        logger.info("Turbo Mode: Removing backbone WTE layer...")
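        # Presumably the Turbo variant feeds the transformer pre-computed
        # embeddings, so the backbone's own word-token embedding table (wte)
        # is dead weight and is deleted to save memory.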
        if hasattr(new_t3_model.tfmr, "wte"):
            del new_t3_model.tfmr.wte
    # Clean up memory
    del tts_engine_original
    del pretrained_t3_state_dict
    # 5. PREPARE ENGINE FOR TRAINING
    # Reload engine components (VoiceEncoder, S3Gen) but inject our new T3
    tts_engine_new = EngineClass.from_local(cfg.model_dir, device="cpu")
    tts_engine_new.t3 = new_t3_model

    # Freeze other components
    logger.info("Freezing S3Gen and VoiceEncoder...")
    for param in tts_engine_new.ve.parameters():
        param.requires_grad = False
    for param in tts_engine_new.s3gen.parameters():
        param.requires_grad = False

    # Enable training for T3
    tts_engine_new.t3.train()
    for param in tts_engine_new.t3.parameters():
        param.requires_grad = True
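
    # Only T3 parameters remain trainable; the wrapper below hands exactly this
    # module to the HF Trainer.

    # Optional one-time preprocessing; three dataset layouts are supported
    # (LJSpeech-style metadata, a JSON manifest, or paired audio/text files).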
    if cfg.preprocess:
        logger.info("Preprocessing dataset...")
        if cfg.ljspeech:
            preprocess_dataset_ljspeech(cfg, tts_engine_new)
        elif cfg.json_format:
            preprocess_dataset_json_based(cfg, tts_engine_new)
        else:
            preprocess_dataset_file_based(cfg, tts_engine_new)
    else:
        logger.info("Skipping the dataset preprocessing step...")
    # 6. DATASET & WRAPPER
    logger.info("Initializing Dataset...")
    train_ds = ChatterboxDataset(cfg)

    trainer_callbacks = []
    if cfg.is_inference:
        inference_cb = InferenceCallback(cfg)
        trainer_callbacks.append(inference_cb)

    model_wrapper = ChatterboxTrainerWrapper(tts_engine_new.t3)

    if cfg.is_turbo:
        logger.info("Using Turbo Data Collator (with dynamic prompt masking)")
        selected_collator = data_collator_turbo
    else:
        logger.info("Using Standard Data Collator")
        selected_collator = data_collator_standart
    # 7. TRAINING ARGUMENTS
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.grad_accum,
        learning_rate=cfg.learning_rate,
        num_train_epochs=cfg.num_epochs,
        save_strategy="steps",
        save_steps=cfg.save_steps,
        logging_strategy="epoch",
        remove_unused_columns=False,  # Required for our custom wrapper
        dataloader_num_workers=cfg.dataloader_num_workers,
        report_to=["tensorboard"],
        fp16=False,
        bf16=True,
        save_total_limit=cfg.save_total_limit,
        gradient_checkpointing=True,  # Trades extra compute for a large cut in activation VRAM
        dataloader_persistent_workers=True,
        dataloader_pin_memory=True,
    )
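
    # NOTE: bf16=True requires Ampere-or-newer NVIDIA hardware; on older GPUs,
    # set bf16=False and consider fp16=True instead.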
    trainer = Trainer(
        model=model_wrapper,
        args=training_args,
        train_dataset=train_ds,
        data_collator=selected_collator,
        callbacks=trainer_callbacks,
    )

    logger.info("Starting Training Loop...")
    trainer.train()
    # 8. SAVE FINAL MODEL
    logger.info("Training complete. Saving model...")
    os.makedirs(cfg.output_dir, exist_ok=True)
    filename = "t3_turbo_finetuned.safetensors" if cfg.is_turbo else "t3_finetuned.safetensors"
    final_model_path = os.path.join(cfg.output_dir, filename)
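    # Only the finetuned T3 state dict is written; S3Gen and the VoiceEncoder
    # remain those of the base checkpoint.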
    save_file(tts_engine_new.t3.state_dict(), final_model_path)
    logger.info(f"Model saved to: {final_model_path}")
if __name__ == "__main__":
    main()
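
# Typical invocation (configuration is read from TrainConfig in src/config.py):
#   python train.py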