Commit 7798c3f

Manan17 (Manan Shah) and vaibhavjindal authored
Adding support for apo losses, sppo_hard and nca_pair (#841)
## Summary

This PR adds support for the apo_zero, apo_down, sppo_hard, and nca_pair losses, mirroring their implementations in TRL's DPOTrainer (https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py).

## Testing Done

`python -m pytest test/chunked_loss/test_dpo_loss.py`

- Hardware Type: H100
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Manan Shah <[email protected]>
Co-authored-by: Vaibhav Jindal <[email protected]>
1 parent c5aa4d2 commit 7798c3f
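Editorial usage sketch (not part of this commit): selecting a loss variant through the new `loss_type` keyword. It assumes only the module path, class name, and constructor keywords visible in the diff below, with every other constructor argument left at its default.

```python
# Sketch: choosing a preference loss variant via the new loss_type keyword.
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

# loss_type="sigmoid" is the default and keeps the original DPO behaviour.
dpo_loss = LigerFusedLinearDPOLoss(loss_type="sigmoid")

# The newly added variants are selected the same way.
apo_zero_loss = LigerFusedLinearDPOLoss(loss_type="apo_zero")
apo_down_loss = LigerFusedLinearDPOLoss(loss_type="apo_down")
sppo_loss = LigerFusedLinearDPOLoss(loss_type="sppo_hard")
nca_loss = LigerFusedLinearDPOLoss(loss_type="nca_pair")
```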

File tree

2 files changed (+634, −3 lines)


src/liger_kernel/chunked_loss/dpo_loss.py

Lines changed: 54 additions & 3 deletions
@@ -13,6 +13,7 @@ def preference_loss_fn(
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ def preference_loss_fn(
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios
 
-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
 
         return loss, chosen_rewards, rejected_rewards
 
     @classmethod
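For reference only (editorial note, not part of the diff): the branches above implement the following per-pair objectives, writing Δ_c = chosen_logps − ref_chosen_logps and Δ_r = rejected_logps − ref_rejected_logps for the chosen/rejected log-ratios and σ for the sigmoid; each loss is then summed and divided by full_target.shape[0] // 2.

```latex
% Compact form of the loss branches above.
\begin{aligned}
\mathcal{L}_{\text{sigmoid}}    &= -\log \sigma\!\big(\beta(\Delta_c - \Delta_r)\big) \\
\mathcal{L}_{\text{apo\_zero}}  &= \big(1 - \sigma(\beta \Delta_c)\big) + \sigma(\beta \Delta_r) \\
\mathcal{L}_{\text{apo\_down}}  &= \sigma(\beta \Delta_c) + \big(1 - \sigma(\beta(\Delta_c - \Delta_r))\big) \\
\mathcal{L}_{\text{sppo\_hard}} &= \Big(\Delta_c - \tfrac{1}{2\beta}\Big)^{2} + \Big(\Delta_r + \tfrac{1}{2\beta}\Big)^{2} \\
\mathcal{L}_{\text{nca\_pair}}  &= -\log \sigma(\beta \Delta_c) - \tfrac{1}{2}\log \sigma(-\beta \Delta_c) - \tfrac{1}{2}\log \sigma(-\beta \Delta_r)
\end{aligned}
```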
@@ -70,6 +113,7 @@ def forward(
         use_ref_model=True,
         average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.
@@ -108,12 +152,13 @@ def forward(
             ref_bias=ref_bias,
             average_log_prob=average_log_prob,
             chunk_size=chunk_size,
+            loss_type=loss_type,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
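Editorial aside on the backward change above: `torch.autograd.Function.backward` must return one gradient (or `None`) per argument of `forward`, and non-tensor arguments such as the new `loss_type` string take a `None` placeholder, which is why the return gains exactly one extra `None`. A tiny self-contained sketch with hypothetical names (ScaleByBeta is not from the repo):

```python
# Toy illustration of the None-per-non-tensor-argument rule in backward.
import torch


class ScaleByBeta(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, beta, loss_type="sigmoid"):
        ctx.beta = beta
        return x * beta

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for x, then None for
        # beta and None for loss_type, which are not tensors and need no gradient.
        return grad_output * ctx.beta, None, None


x = torch.randn(4, requires_grad=True)
ScaleByBeta.apply(x, 0.1, "sigmoid").sum().backward()
```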
@@ -130,6 +175,7 @@ def __init__(
         use_ref_model: bool = True,
         average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:
@@ -149,6 +195,10 @@ def __init__(
         self.use_ref_model = use_ref_model
         self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")
 
     def forward(
         self,
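Editorial note, not part of the diff: because the guard above runs in `__init__`, an unsupported `loss_type` fails at construction time rather than deep inside the fused kernel. A small hypothetical snippet, assuming only the module path and class name shown in this file:

```python
# The constructor-level check rejects unknown loss types immediately.
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

try:
    LigerFusedLinearDPOLoss(loss_type="hinge")  # not in the supported set
except ValueError as err:
    print(err)  # Unsupported loss_type: hinge. Supported types are: {...}
```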
@@ -175,4 +225,5 @@ def forward(
             self.use_ref_model,
             self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )
