Commit 6385656: update dpo
1 parent cc234bd

23 files changed: +345 additions, -2024 deletions

cosyvoice/bin/inference.py renamed to cosyvoice/bin/inference_deprecated.py

Lines changed: 1 addition & 0 deletions
@@ -122,4 +122,5 @@ def main():


 if __name__ == '__main__':
+    logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
     main()

cosyvoice/bin/train.py

Lines changed: 23 additions & 3 deletions
@@ -27,6 +27,7 @@

 from torch.distributed.elastic.multiprocessing.errors import record

+from cosyvoice.utils.losses import DPOLoss
 from cosyvoice.utils.executor import Executor
 from cosyvoice.utils.train_utils import (
     init_distributed,
@@ -43,6 +44,7 @@ def get_args():
                         choices=['torch_ddp', 'deepspeed'],
                         help='Engine for paralleled training')
     parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
@@ -73,6 +75,10 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use automatic mixed precision training')
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
@@ -113,7 +119,7 @@ def main():

     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan)
+        init_dataset_and_dataloader(args, configs, gan, args.dpo)

     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
@@ -122,6 +128,8 @@ def main():
     writer = init_summarywriter(args)

     # load checkpoint
+    if args.dpo is True:
+        configs[args.model].forward = configs[args.model].forward_dpo
     model = configs[args.model]
     start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
@@ -150,13 +158,25 @@ def main():
         info_dict['epoch'] = start_epoch
         save_model(model, 'init', info_dict)

+    # DPO related
+    if args.dpo is True:
+        ref_model = deepcopy(configs[args.model])
+        state_dict = torch.load(args.ref_model, map_location='cpu')
+        ref_model.load_state_dict(state_dict, strict=False)
+        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+        # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
+        ref_model = wrap_cuda_model(args, ref_model)
+    else:
+        ref_model, dpo_loss = None, None
+
     # Get executor
-    executor = Executor(gan=gan)
+    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
     executor.step = start_step

     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
     print('start step {} start epoch {}'.format(start_step, start_epoch))
+
     # Start training loop
     for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
@@ -167,7 +187,7 @@ def main():
             executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                         writer, info_dict, scaler, group_join)
         else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
         dist.destroy_process_group(group_join)
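Note: the training entry point now wires DPO in directly. Passing --dpo switches the dataset, the model forward (rebound to forward_dpo), and the Executor into preference-optimization mode, while --ref_model points at the frozen reference checkpoint (e.g. train.py ... --dpo --ref_model <path>). For orientation, the sketch below shows the objective that a DPOLoss(beta, label_smoothing, ipo) of this shape typically computes; it is an illustrative assumption, and the actual implementation in cosyvoice/utils/losses.py may differ in details.

# Hedged sketch of a standard DPO objective; not the repository's exact DPOLoss.
import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps,
                    beta=0.01, label_smoothing=0.0, ipo=False):
    # How much more the policy prefers the chosen sample than the frozen reference does,
    # minus the same quantity for the rejected sample.
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    if ipo:
        # IPO variant: squared distance to the 1/(2*beta) margin.
        loss = (logits - 1.0 / (2.0 * beta)) ** 2
    else:
        # Standard DPO: sigmoid loss, optionally label-smoothed.
        loss = (-F.logsigmoid(beta * logits) * (1.0 - label_smoothing)
                - F.logsigmoid(-beta * logits) * label_smoothing)
    return loss.mean()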

cosyvoice/bin/train_dpo.py

Lines changed: 0 additions & 187 deletions
This file was deleted.

cosyvoice/cli/model.py

Lines changed: 2 additions & 1 deletion
@@ -103,7 +103,7 @@ def get_trt_kwargs(self):
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
                 for i in self.llm.inference_bistream(text=text,
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
@@ -279,6 +279,7 @@ def load_vllm(self, model_dir):
                                  enable_prompt_embeds=True,
                                  gpu_memory_utilization=0.2)
         self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+        self.llm.lock = threading.Lock()
         del self.llm.llm.model.model.layers

     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
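Note: the second hunk attaches a threading.Lock next to the single shared vLLM engine. The likely intent is to serialize engine access when several llm_job threads synthesize concurrently; the snippet below is a purely illustrative sketch of that pattern, and the guarded call is a placeholder rather than the repository's actual call site.

import threading

engine_lock = threading.Lock()   # plays the role of self.llm.lock in the diff

def run_request(engine, request):
    # Only one thread may drive the shared engine at a time.
    with engine_lock:
        return engine.process(request)   # placeholder for the real vLLM engine call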

cosyvoice/dataset/dataset.py

Lines changed: 5 additions & 18 deletions
@@ -14,14 +14,13 @@
 # limitations under the License.

 import random
-import json
 import math
 from functools import partial

 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset
-from cosyvoice.utils.file_utils import read_lists, read_json_lists
+from cosyvoice.utils.file_utils import read_lists


 class Processor(IterableDataset):
@@ -127,10 +126,9 @@ def Dataset(data_list_file,
             data_pipeline,
             mode='train',
             gan=False,
+            dpo=False,
             shuffle=True,
-            partition=True,
-            tts_file='',
-            prompt_utt2data=''):
+            partition=True):
     """ Construct dataset from arguments

         We have two shuffle stage in the Dataset. The first is global
@@ -142,23 +140,12 @@ def Dataset(data_list_file,
             tokenizer (BaseTokenizer): tokenizer to tokenize
             partition(bool): whether to do data partition in terms of rank
     """
-    assert mode in ['train', 'inference']
     lists = read_lists(data_list_file)
-    if mode == 'inference':
-        with open(tts_file) as f:
-            tts_data = json.load(f)
-        utt2lists = read_json_lists(prompt_utt2data)
-        # filter unnecessary file in inference mode
-        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)
-    if mode == 'inference':
-        # map partial arg to parquet_opener func in inference mode
-        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
-    if gan is True:
-        # map partial arg to padding func in gan mode
-        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
+    # map partial arg to padding func
+    data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
     for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
     return dataset
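Note: the gan and dpo flags are now both forwarded to the final stage of the data pipeline via functools.partial. The padding/collate function itself lives in cosyvoice/dataset/processor.py and is not part of this diff; the sketch below only illustrates what a dpo=True branch plausibly does, namely emitting a rejected speech-token sequence alongside the preferred one so that forward_dpo and DPOLoss can compare the two. The field names are assumptions.

# Hedged sketch; not the repository's actual padding function.
import torch
from torch.nn.utils.rnn import pad_sequence

def padding_sketch(samples, gan=False, dpo=False):
    batch = {
        'speech_token': pad_sequence([s['speech_token'] for s in samples], batch_first=True),
        'speech_token_len': torch.tensor([len(s['speech_token']) for s in samples]),
    }
    if dpo:
        # Hypothetical field name for the dispreferred sample paired with each utterance.
        batch['reject_speech_token'] = pad_sequence(
            [s['reject_speech_token'] for s in samples], batch_first=True)
        batch['reject_speech_token_len'] = torch.tensor(
            [len(s['reject_speech_token']) for s in samples])
    return batch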
