@@ -27,6 +27,7 @@
 
 from torch.distributed.elastic.multiprocessing.errors import record
 
+from cosyvoice.utils.losses import DPOLoss
 from cosyvoice.utils.executor import Executor
 from cosyvoice.utils.train_utils import (
     init_distributed,
@@ -43,6 +44,7 @@ def get_args():
                         choices=['torch_ddp', 'deepspeed'],
                         help='Engine for paralleled training')
     parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
@@ -73,6 +75,10 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use automatic mixed precision training')
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
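
A note on the two new flags: `--ref_model` stays `required=False` at the argparse level, so a `--dpo` run that omits it only fails later, when `torch.load(args.ref_model)` receives `None`. A small guard right after parsing (hypothetical, not part of this diff) would surface the mistake early:

```python
# Hypothetical early check, not in this diff: --dpo needs a reference
# checkpoint, otherwise torch.load(args.ref_model) fails later on None.
if args.dpo and args.ref_model is None:
    parser.error('--ref_model is required when --dpo is set')
```
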
@@ -113,7 +119,7 @@ def main():
 
     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan)
+        init_dataset_and_dataloader(args, configs, gan, args.dpo)
 
     # Do some sanity checks and save config to args.model_dir
     configs = check_modify_and_save_config(args, configs)
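
The extra trailing argument threads `args.dpo` into dataset construction, presumably so each batch carries a preferred (chosen) and a dispreferred (rejected) sample for the DPO objective. A minimal sketch of what such a pairing collate could look like (names and layout are assumptions, not CosyVoice's actual dataset code):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Sketch only: assumes each sample dict holds 'chosen' and 'rejected'
# 1-D LongTensors; real CosyVoice batches carry more fields.
def dpo_collate(samples):
    return {
        'chosen_token': pad_sequence([s['chosen'] for s in samples],
                                     batch_first=True, padding_value=0),
        'rejected_token': pad_sequence([s['rejected'] for s in samples],
                                       batch_first=True, padding_value=0),
    }
```
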
@@ -122,6 +128,8 @@ def main():
 
     writer = init_summarywriter(args)
 
     # load checkpoint
+    if args.dpo is True:
+        configs[args.model].forward = configs[args.model].forward_dpo
     model = configs[args.model]
     start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
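
Reassigning `configs[args.model].forward` to the bound `forward_dpo` method swaps the code path every later `model(batch)` call dispatches to, because `nn.Module.__call__` looks up `self.forward` and the instance attribute shadows the class method. A toy illustration of the pattern (not the CosyVoice model itself):

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

    def forward_dpo(self, x):
        # the real forward_dpo would score chosen and rejected sequences
        return x - 1

m = Toy()
m.forward = m.forward_dpo       # instance attribute shadows the class method
print(m(torch.tensor(1)))       # tensor(0): __call__ now reaches forward_dpo
```
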
@@ -150,13 +158,25 @@ def main():
     info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
+    # DPO related
+    if args.dpo is True:
+        ref_model = deepcopy(configs[args.model])
+        state_dict = torch.load(args.ref_model, map_location='cpu')
+        ref_model.load_state_dict(state_dict, strict=False)
+        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+        # NOTE wrapping ref_model in DDP may be unnecessary because its parameters are never updated
+        ref_model = wrap_cuda_model(args, ref_model)
+    else:
+        ref_model, dpo_loss = None, None
+
     # Get executor
-    executor = Executor(gan=gan)
+    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
     executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
     print('start step {} start epoch {}'.format(start_step, start_epoch))
+
     # Start training loop
     for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
@@ -167,7 +187,7 @@ def main():
             executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                         writer, info_dict, scaler, group_join)
         else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
     dist.destroy_process_group(group_join)
 
 
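For context, the `DPOLoss` imported at the top is constructed here with `beta=0.01, label_smoothing=0.0, ipo=False`. Its actual implementation lives in `cosyvoice/utils/losses.py` and may differ in detail; the sketch below is a minimal version of the standard DPO/IPO objective those hyper-parameters correspond to (Rafailov et al., 2023; Azar et al., 2023):

```python
import torch
import torch.nn.functional as F

class DPOLoss(torch.nn.Module):
    """Minimal sketch of a DPO loss, not CosyVoice's exact implementation."""

    def __init__(self, beta: float = 0.01, label_smoothing: float = 0.0, ipo: bool = False):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(self, policy_chosen_logps, policy_rejected_logps,
                reference_chosen_logps, reference_rejected_logps):
        # Log-ratio of chosen over rejected, for policy and reference.
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        if self.ipo:
            # IPO regresses the log-ratio gap towards 1 / (2 * beta).
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            # Conservative DPO with optional label smoothing.
            losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                      - F.logsigmoid(-self.beta * logits) * self.label_smoothing)
        # Implicit rewards, detached so they are logging-only.
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses.mean(), chosen_rewards, rejected_rewards
```

This also explains why the reference model is loaded with `map_location='cpu'` and never given an optimizer: it only supplies the frozen `reference_*_logps` terms above.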