Skip to content

Commit 814f3d2

Browse files
committed
support gpt training
1 parent 3516659 commit 814f3d2

File tree

7 files changed

+296
-18
lines changed

7 files changed

+296
-18
lines changed

configs/gpt.yaml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Training configuration for the GPT (text-to-semantic, "s1") model.
# Values under `train` are overridden at runtime by GPTTrainParams
# (see src/train/gpt.py); the rest is consumed by the lightning module.
train:
  seed: 1234
  epochs: 20
  batch_size: 8
  save_every_n_epoch: 1
  precision: 16-mixed  # forced to "32" at runtime when half precision is disabled
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 54
  num_workers: 4
  pad_val: 1024  # same as EOS in model
model:
  vocab_size: 1025
  phoneme_vocab_size: 732
  embedding_dim: 512
  hidden_dim: 512
  head: 16
  linear_units: 2048
  n_layer: 24
  dropout: 0
  EOS: 1024
  random_bert: 0
inference:
  top_k: 15

requirements.txt

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
torch
2-
tqdm
3-
six
1+
torch~=2.5.1
2+
tqdm~=4.67.1
3+
six~=1.17.0
44
numpy==1.23.5
55
audioread==3.0.0
66
librosa==0.9.2
@@ -12,33 +12,46 @@ beartype==0.14.1
1212
rotary_embedding_torch==0.3.5
1313
einops==0.6.0
1414
librosa==0.9.2
15-
pytorch-lightning
16-
transformers
17-
ffmpeg-python
15+
pytorch-lightning~=2.0.0
16+
transformers~=4.47.1
17+
ffmpeg-python~=0.2.0
1818
onnxruntime; sys_platform == 'darwin'
1919
onnxruntime-gpu; sys_platform != 'darwin'
20-
fastapi
21-
uvicorn
22-
pydantic
20+
fastapi~=0.115.6
21+
uvicorn~=0.34.0
22+
pydantic~=2.10.5
2323
LangSegment>=0.2.0
24-
pypinyin
24+
pypinyin~=0.53.0
2525
jieba_fast
2626
opencc; sys_platform != 'linux'
2727
opencc==1.1.1; sys_platform == 'linux'
2828
pyopenjtalk>=0.3.4
29-
nltk
30-
wordsegment
29+
nltk~=3.9.1
30+
wordsegment~=1.3.1
3131
g2p_en
32-
jamo
32+
jamo~=0.4.1
3333
ko_pron
34-
g2pk2
35-
pyjyutping
36-
cn2an
34+
g2pk2~=0.0.3
35+
pyjyutping~=1.0.0
36+
cn2an~=0.5.23
3737
modelscope==1.10.0
3838
sentencepiece
3939
Faster_Whisper
4040
funasr==1.0.27
4141
torchaudio
4242
python-mecab-ko
43-
opencc
44-
matplotlib
43+
opencc~=1.1.9
44+
matplotlib~=3.9.4
45+
46+
PyYAML~=6.0.2
47+
joblib~=1.4.2
48+
faster-whisper~=1.1.1
49+
soundfile~=0.13.0
50+
packaging~=24.2
51+
rotary-embedding-torch~=0.3.5
52+
onnxruntime~=1.19.2
53+
requests~=2.32.3
54+
g2p-en~=2.1.0
55+
pandas~=2.2.3
56+
torchmetrics~=1.6.1
57+
regex~=2024.11.6

src/service/train.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/usr/bin/env python
2+
# -*- encoding=utf8 -*-
3+
import traceback
4+
5+
from src.train.gpt import GPTTrainParams, GPTTrain
6+
from src.utils.response import EaseVoiceResponse, ResponseStatus
7+
8+
9+
class TrainService(object):
    """Facade that wires GPT training parameters to the trainer.

    Construction eagerly builds the underlying ``GPTTrain`` instance, so
    configuration errors surface at service creation rather than at
    ``train()`` time.
    """

    def __init__(self, gpt_params: GPTTrainParams):
        self.gpt_train = GPTTrain(gpt_params)

    def train(self) -> EaseVoiceResponse:
        """Run GPT training and wrap the outcome in an ``EaseVoiceResponse``.

        Returns:
            ``SUCCESS`` on normal completion, ``FAILED`` carrying the
            exception message otherwise. Never raises: this is the service
            boundary, so all exceptions are converted into a response.
        """
        try:
            self.gpt_train.train()
            return EaseVoiceResponse(ResponseStatus.SUCCESS, "Training completed successfully")
        except Exception as e:
            # Fixed: format_exc() already ends with the exception message, so
            # the old `print(traceback.format_exc(), e)` printed it twice.
            print(traceback.format_exc())
            return EaseVoiceResponse(ResponseStatus.FAILED, str(e))

src/train/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/usr/bin/env python
2+
# -*- encoding=utf8 -*-

src/train/gpt.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#!/usr/bin/env python
2+
# -*- encoding=utf8 -*-
3+
4+
import logging
5+
import os
6+
import platform
7+
import re
8+
from collections import OrderedDict
9+
from dataclasses import dataclass
10+
from pathlib import Path
11+
12+
import torch
13+
import yaml
14+
from pytorch_lightning import Trainer
15+
from pytorch_lightning import seed_everything
16+
from pytorch_lightning.callbacks import ModelCheckpoint
17+
from pytorch_lightning.loggers import TensorBoardLogger
18+
from pytorch_lightning.strategies import DDPStrategy
19+
20+
from src.easevoice.soundstorm.auto_reg.data.data_module import Text2SemanticDataModule
21+
from src.easevoice.soundstorm.auto_reg.models.t2s_lightning_module import Text2SemanticLightningModule
22+
from src.utils.config import gpt_config_path, train_output, cfg, gpt_pretrained_model_path, semantic_output, text_output_name, train_gpt_logs_output
23+
24+
25+
@dataclass
class GPTTrainParams:
    """Hyper-parameters and filesystem paths for one GPT (s1) training run."""

    batch_size: int = 12  # halved automatically by GPTTrain when half precision is off
    total_epochs: int = 15
    save_every_epoch: int = 5  # checkpoint/export interval, in epochs
    if_dpo: bool = False  # enable the DPO training variant
    if_save_latest: bool = True  # prune older trainer checkpoints after each save
    if_save_every_weights: bool = True  # also export half-precision weights each save
    gpu_ids: str = "0"
    model_path: str = gpt_pretrained_model_path  # pretrained s1 checkpoint to fine-tune
    processing_path: str = ""
    normalize_path: str = ""  # root dir holding normalized data; outputs go under it
    output_model_name: str = "gpt"  # basename for exported weight files
38+
39+
40+
class GPTCheckpoint(ModelCheckpoint):
    """``ModelCheckpoint`` with latest-only pruning and half-precision export.

    On top of standard Lightning checkpointing this callback can:
      * keep only the newest trainer checkpoint (``if_save_latest``), and
      * dump a half-precision copy of the model weights plus its config to
        ``half_weights_save_dir`` after every save (``if_save_every_weights``).

    NOTE(review): relies on Lightning-private helpers
    (``_should_save_on_train_epoch_end``, ``_monitor_candidates``,
    ``_save_topk_checkpoint``, ``_save_last_checkpoint``) — pinned to
    pytorch-lightning~=2.0 in requirements.
    """

    def __init__(
            self,
            config,
            if_save_latest,
            if_save_every_weights,
            half_weights_save_dir,
            output_name,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.if_save_latest = if_save_latest
        self.if_save_every_weights = if_save_every_weights
        self.half_weights_save_dir = half_weights_save_dir
        self.output_name = output_name
        self.config = config

    def on_train_epoch_end(self, trainer, pl_module):
        if self._should_save_on_train_epoch_end(trainer):
            monitor_candidates = self._monitor_candidates(trainer)
            if (
                    self._every_n_epochs >= 1
                    and (trainer.current_epoch + 1) % self._every_n_epochs == 0
            ):
                # Snapshot the existing checkpoints BEFORE saving so the file
                # written below is never deleted by the pruning step.
                to_clean = []
                if self.if_save_latest:  # fixed: was `== True`
                    to_clean = list(os.listdir(self.dirpath))
                self._save_topk_checkpoint(trainer, monitor_candidates)
                if self.if_save_latest:
                    for name in to_clean:
                        try:
                            os.remove(os.path.join(self.dirpath, name))
                        except OSError:
                            # Best effort: another rank may have removed it,
                            # or the entry may not be a removable file.
                            pass
                if self.if_save_every_weights:
                    to_save_od = OrderedDict()
                    to_save_od["weight"] = OrderedDict()
                    state_dict = trainer.strategy._lightning_module.state_dict()
                    for key in state_dict:
                        # Export in half precision to halve the file size.
                        to_save_od["weight"][key] = state_dict[key].half()
                    to_save_od["config"] = self.config
                    to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
                    # Under DDP only rank 0 writes the export file.
                    if os.environ.get("LOCAL_RANK", "0") == "0":
                        new_path = os.path.join(
                            self.half_weights_save_dir,
                            "%s-e%s.ckpt" % (self.output_name, trainer.current_epoch + 1),
                        )
                        torch.save(to_save_od, new_path)
            self._save_last_checkpoint(trainer, monitor_candidates)
91+
92+
93+
class GPTTrain(object):
    """Configures and runs GPT (text-to-semantic, s1) training via PyTorch Lightning."""

    def __init__(self, params: "GPTTrainParams"):
        """Build trainer, model and data module from gpt.yaml plus *params*.

        Side effects: creates the train/log/ckpt output directories under
        ``params.normalize_path`` and sets ``hz``/``MASTER_ADDR`` env vars.
        """
        # Silence noisy third-party loggers.
        logging.getLogger("numba").setLevel(logging.WARNING)
        logging.getLogger("matplotlib").setLevel(logging.WARNING)
        torch.set_float32_matmul_precision("high")
        # Base config from configs/gpt.yaml; selected keys overridden below.
        with open(gpt_config_path, "r") as f:
            data = f.read()
        self.config = yaml.load(data, Loader=yaml.FullLoader)
        self.processing_path = params.processing_path
        self.normalize_path = params.normalize_path
        self.train_output = os.path.join(params.normalize_path, train_output)
        self.train_logs_output = os.path.join(params.normalize_path, train_gpt_logs_output)
        self.train_ckpts_output = os.path.join(self.train_logs_output, "ckpt")
        os.makedirs(self.train_output, exist_ok=True)
        os.makedirs(self.train_logs_output, exist_ok=True)
        os.makedirs(self.train_ckpts_output, exist_ok=True)
        self.cfg = cfg
        if not self.cfg.is_half:
            # Full precision roughly doubles memory, so halve the batch size.
            self.config["train"]["precision"] = "32"
            params.batch_size = max(1, params.batch_size // 2)
        self.config["train"]["batch_size"] = params.batch_size
        self.config["train"]["epochs"] = params.total_epochs
        self.config["train"]["save_every_n_epoch"] = params.save_every_epoch
        self.config["train"]["if_dpo"] = params.if_dpo
        self.config["train"]["if_save_latest"] = params.if_save_latest
        self.config["train"]["if_save_every_weights"] = params.if_save_every_weights
        self.config["pretrained_s1"] = params.model_path
        self.config["train"]["half_weights_save_dir"] = self.train_output
        self.config["train_semantic_path"] = os.path.join(params.normalize_path, semantic_output)
        self.config["train_phoneme_path"] = os.path.join(params.normalize_path, text_output_name)
        self.config["logs_output_dir"] = self.train_logs_output
        self.config["train"]["output_name"] = params.output_model_name
        os.environ["hz"] = "25hz"
        seed_everything(self.config["train"]["seed"], workers=True)
        ckpt_callback: ModelCheckpoint = GPTCheckpoint(
            config=self.config,
            if_save_latest=self.config["train"]["if_save_latest"],
            if_save_every_weights=self.config["train"]["if_save_every_weights"],
            half_weights_save_dir=self.config["train"]["half_weights_save_dir"],
            output_name=self.config["train"]["output_name"],
            save_top_k=-1,
            monitor="top_3_acc",
            mode="max",
            save_on_train_epoch_end=True,
            every_n_epochs=self.config["train"]["save_every_n_epoch"],
            dirpath=self.train_ckpts_output,
        )
        logger = TensorBoardLogger(name="log", save_dir=self.train_logs_output)
        os.environ["MASTER_ADDR"] = "localhost"
        self.trainer: Trainer = Trainer(
            max_epochs=self.config["train"]["epochs"],
            accelerator=self.cfg.device,
            limit_val_batches=0,
            devices=-1 if torch.cuda.is_available() else 1,
            benchmark=False,
            fast_dev_run=False,
            # NCCL is unavailable on Windows; fall back to gloo there.
            strategy=DDPStrategy(
                process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
            ) if torch.cuda.is_available() else "auto",
            precision=self.config["train"]["precision"],
            logger=logger,
            num_sanity_val_steps=0,
            callbacks=[ckpt_callback],
            use_distributed_sampler=False,
        )
        self.model: Text2SemanticLightningModule = Text2SemanticLightningModule(
            self.config, Path(self.train_logs_output)
        )

        self.data_module: Text2SemanticDataModule = Text2SemanticDataModule(
            self.config,
            train_semantic_path=self.config["train_semantic_path"],
            train_phoneme_path=self.config["train_phoneme_path"],
        )
        # Resume from the newest checkpoint if one exists.
        trainer_ckpt_path = self._get_newest_ckpt(os.listdir(self.train_ckpts_output))
        self.trainer_ckpt_path = os.path.join(self.train_ckpts_output, trainer_ckpt_path) if trainer_ckpt_path else None

    def train(self):
        """Run the fit loop, resuming from ``self.trainer_ckpt_path`` when set."""
        self.trainer.fit(self.model, self.data_module, ckpt_path=self.trainer_ckpt_path)

    @staticmethod
    def _get_newest_ckpt(file_list):
        """Return the filename with the highest ``(epoch, step)``, or ``None``.

        Args:
            file_list: filenames (e.g. from ``os.listdir``); names not matching
                ``epoch=<n>-step=<m>.ckpt`` are ignored.

        Returns:
            The newest matching filename, or ``None`` when *file_list* is
            empty/None or contains no matching name.
        """
        if not file_list:
            return None

        pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
        extracted_info = []
        for name in file_list:
            match = re.match(pattern, name)
            if match:
                epoch = int(match.group(1))
                step = int(match.group(2))
                extracted_info.append((epoch, step, name))
        if not extracted_info:
            # Fixed: previously raised IndexError when the directory held
            # only non-matching files (e.g. stray/temp files).
            return None
        return max(extracted_info, key=lambda x: (x[0], x[1]))[2]

src/utils/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
wav_output = "5-wav32k"
3333
semantic_output = "6-name2semantic.tsv"
3434
s2config_path = os.path.join(base_path, "configs", "s2.json")
35+
gpt_config_path = os.path.join(base_path, "configs", "gpt.yaml")  # GPT (s1) training config
train_output = "train"  # sub-dir (under normalize_path) for exported model weights
train_gpt_logs_output = "gpt_logs"  # sub-dir for lightning logs and trainer checkpoints
# Pretrained s1 checkpoint used as the starting point for GPT fine-tuning.
gpt_pretrained_model_path = os.path.join(normalize_root, "gsv-v2final-pretrained", "s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
3539
cfg = config.GlobalCFG()
3640

3741
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

tests/train_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/usr/bin/env python
2+
# -*- encoding=utf8 -*-
3+
4+
import unittest
5+
6+
from src.service.train import TrainService
7+
from src.train.gpt import GPTTrainParams
8+
from src.utils.response import ResponseStatus
9+
10+
11+
class TestTrain(unittest.TestCase):
    """End-to-end GPT training smoke test (needs ./output fixtures on disk)."""

    @classmethod
    def setUpClass(cls):
        # Fixed: the service was built in the class body, so the heavy
        # GPTTrain.__init__ (directory creation, config/model loading) ran at
        # import/collection time — even when this test was deselected.
        # setUpClass defers that work until the test class actually runs.
        cls.service = TrainService(gpt_params=GPTTrainParams(
            processing_path="./output",
            normalize_path="./output/test",
            output_model_name="test",
        ))

    def test_train(self):
        resp = self.service.train()
        self.assertEqual(resp.status, ResponseStatus.SUCCESS)

0 commit comments

Comments
 (0)