@@ -299,9 +299,15 @@ def parse_args(input_args=None):
299299 parser .add_argument (
300300 "--beta_dpo" ,
301301 type = int ,
302- default = 5000 ,
302+ default = 2500 ,
303303 help = "DPO KL Divergence penalty." ,
304304 )
305+ parser .add_argument (
306+ "--loss_type" ,
307+ type = str ,
308+ default = "sigmoid" ,
309+ help = "DPO loss type. Can be one of 'sigmoid' (default), 'ipo', or 'cpo'" ,
310+ )
305311 parser .add_argument (
306312 "--learning_rate" ,
307313 type = float ,
@@ -858,12 +864,19 @@ def collate_fn(examples):
858864 accelerator .unwrap_model (unet ).enable_adapters ()
859865
860866 # Final loss.
861- scale_term = - 0.5 * args .beta_dpo
862- inside_term = scale_term * (model_diff - ref_diff )
863- loss = - 1 * F .logsigmoid (inside_term ).mean ()
867+ logits = ref_diff - model_diff
868+ if args .loss_type == "sigmoid" :
869+ loss = - 1 * F .logsigmoid (args .beta_dpo * logits ).mean ()
870+ elif args .loss_type == "hinge" :
871+ loss = torch .relu (1 - args .beta_dpo * logits ).mean ()
872+ elif args .loss_type == "ipo" :
873+ losses = (logits - 1 / (2 * args .beta )) ** 2
874+ loss = losses .mean ()
875+ else :
876+ raise ValueError (f"Unknown loss type { args .loss_type } " )
864877
865- implicit_acc = (inside_term > 0 ).sum ().float () / inside_term .size (0 )
866- implicit_acc += 0.5 * (inside_term == 0 ).sum ().float () / inside_term .size (0 )
878+ implicit_acc = (logits > 0 ).sum ().float () / logits .size (0 )
879+ implicit_acc += 0.5 * (logits == 0 ).sum ().float () / logits .size (0 )
867880
868881 accelerator .backward (loss )
869882 if accelerator .sync_gradients :
0 commit comments