Skip to content

Commit 97ee616

Browse files
authored
add ipo, hinge and cpo loss to dpo trainer (#6788)
add ipo and hinge loss to dpo trainer
1 parent 0fc62d1 commit 97ee616

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

examples/research_projects/diffusion_dpo/train_diffusion_dpo.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,15 @@ def parse_args(input_args=None):
299299
parser.add_argument(
300300
"--beta_dpo",
301301
type=int,
302-
default=5000,
302+
default=2500,
303303
help="DPO KL Divergence penalty.",
304304
)
305+
parser.add_argument(
306+
"--loss_type",
307+
type=str,
308+
default="sigmoid",
309+
help="DPO loss type. Can be one of 'sigmoid' (default), 'hinge', or 'ipo'",
310+
)
305311
parser.add_argument(
306312
"--learning_rate",
307313
type=float,
@@ -858,12 +864,19 @@ def collate_fn(examples):
858864
accelerator.unwrap_model(unet).enable_adapters()
859865

860866
# Final loss.
861-
scale_term = -0.5 * args.beta_dpo
862-
inside_term = scale_term * (model_diff - ref_diff)
863-
loss = -1 * F.logsigmoid(inside_term).mean()
867+
logits = ref_diff - model_diff
868+
if args.loss_type == "sigmoid":
869+
loss = -1 * F.logsigmoid(args.beta_dpo * logits).mean()
870+
elif args.loss_type == "hinge":
871+
loss = torch.relu(1 - args.beta_dpo * logits).mean()
872+
elif args.loss_type == "ipo":
873+
losses = (logits - 1 / (2 * args.beta_dpo)) ** 2
874+
loss = losses.mean()
875+
else:
876+
raise ValueError(f"Unknown loss type {args.loss_type}")
864877

865-
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
866-
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
878+
implicit_acc = (logits > 0).sum().float() / logits.size(0)
879+
implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)
867880

868881
accelerator.backward(loss)
869882
if accelerator.sync_gradients:

0 commit comments

Comments
 (0)