Skip to content

Commit 9ef2cda

Browse files
committed
fix checkstyle
1 parent 10078c2 commit 9ef2cda

File tree

4 files changed

+140
-34
lines changed

4 files changed

+140
-34
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
22
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
34
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
45
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
56
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
67
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
7-
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
88

99
liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
1010
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
1111
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
1212
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
1313
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
1414
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
15-
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
15+
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply

src/liger_kernel/chunked_loss/fused_linear_rlhf.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,14 @@ def forward(
8080
rlhf_loss_fn=cls.rlhf_loss_fn,
8181
)
8282

83-
def fused_fwd_bwd(input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk):
83+
def fused_fwd_bwd(
84+
input_chunk,
85+
selected_token_ids_chunk,
86+
attention_mask_chunk,
87+
advantages_chunk,
88+
ref_input_chunk,
89+
old_per_token_logps_chunk,
90+
):
8491
"""Fused forward and backward for a chunk."""
8592
argnums = (0, 1, 5) if bias is not None else (0, 1)
8693
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
@@ -94,9 +101,21 @@ def fused_fwd_bwd(input_chunk, selected_token_ids_chunk, attention_mask_chunk, a
94101
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
95102
)
96103

97-
def accumulate_chunk(input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk=None, old_per_token_logps_chunk=None):
104+
def accumulate_chunk(
105+
input_chunk,
106+
selected_token_ids_chunk,
107+
attention_mask_chunk,
108+
advantages_chunk,
109+
ref_input_chunk=None,
110+
old_per_token_logps_chunk=None,
111+
):
98112
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
99-
input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk
113+
input_chunk,
114+
selected_token_ids_chunk,
115+
attention_mask_chunk,
116+
advantages_chunk,
117+
ref_input_chunk,
118+
old_per_token_logps_chunk,
100119
)
101120
if bias is not None:
102121
grad_bias.add_(chunk_grad_bias[0])
@@ -132,10 +151,26 @@ def accumulate_chunk(input_chunk, selected_token_ids_chunk, attention_mask_chunk
132151
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
133152
_advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
134153
_ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
135-
_old_per_token_logps_chunks = torch.chunk(old_per_token_logps, chunks=chunks, dim=0) if old_per_token_logps is not None else [None] * chunks
154+
_old_per_token_logps_chunks = (
155+
torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
156+
if old_per_token_logps is not None
157+
else [None] * chunks
158+
)
136159

137-
for input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk in zip(
138-
_input_chunks, _selected_token_ids_chunks, _attention_mask_chunks, _advantages_chunks, _ref_input_chunks, _old_per_token_logps_chunks
160+
for (
161+
input_chunk,
162+
selected_token_ids_chunk,
163+
attention_mask_chunk,
164+
advantages_chunk,
165+
ref_input_chunk,
166+
old_per_token_logps_chunk,
167+
) in zip(
168+
_input_chunks,
169+
_selected_token_ids_chunks,
170+
_attention_mask_chunks,
171+
_advantages_chunks,
172+
_ref_input_chunks,
173+
_old_per_token_logps_chunks,
139174
):
140175
# Mark dynamic dimensions
141176
torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -150,7 +185,14 @@ def accumulate_chunk(input_chunk, selected_token_ids_chunk, attention_mask_chunk
150185
else:
151186
old_per_token_logps_chunk = None
152187

153-
accumulate_chunk(input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk)
188+
accumulate_chunk(
189+
input_chunk,
190+
selected_token_ids_chunk,
191+
attention_mask_chunk,
192+
advantages_chunk,
193+
ref_input_chunk,
194+
old_per_token_logps_chunk,
195+
)
154196

155197
# Combine gradients
156198
grad_input = torch.cat(grad_inputs, dim=0)
@@ -196,7 +238,9 @@ def _compute_chunk_loss(
196238
ref_log_probs = None
197239
if use_ref_model and ref_input_chunk is not None:
198240
with torch.no_grad():
199-
ref_log_probs, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature)
241+
ref_log_probs, _ = LigerFusedLinearRLHFBase.chunk_forward(
242+
ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
243+
)
200244

201245
# Compute chunk loss and metrics using the provided loss function
202246
chunk_loss, chunk_metrics = rlhf_loss_fn(

src/liger_kernel/chunked_loss/grpo_loss.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def rlhf_loss_fn(
4141
if beta != 0.0:
4242
# Compute KL penalty
4343
kl_div = (
44-
torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
44+
torch.exp(ref_token_logprobs - chosen_token_logprobs)
45+
- (ref_token_logprobs - chosen_token_logprobs)
46+
- 1.0
4547
)
4648
# Combine losses
4749
per_token_loss = per_token_loss + beta * kl_div
@@ -58,7 +60,8 @@ def rlhf_loss_fn(
5860
]
5961
if beta != 0.0:
6062
metrics.append(
61-
((kl_div * attention_mask).sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1.0)).sum() / full_batch_size
63+
((kl_div * attention_mask).sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1.0)).sum()
64+
/ full_batch_size
6265
)
6366
return loss, metrics
6467

test/chunked_loss/test_grpo_loss.py

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import torch.nn.functional as F
44

55
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
6-
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
76
from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo
7+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
88
from liger_kernel.utils import infer_device
99
from test.utils import assert_verbose_allclose
1010
from test.utils import set_seed
@@ -16,6 +16,7 @@
1616
# reset torch compiler cache
1717
torch.compiler.reset()
1818

19+
1920
class TorchLMHeadGRPO(torch.nn.Module):
2021
def __init__(
2122
self,
@@ -38,7 +39,7 @@ def __init__(
3839
self.epsilon_high = epsilon_high
3940
self.temperature = temperature
4041
self.use_ref_model = use_ref_model
41-
42+
4243
def forward(
4344
self,
4445
x, # Shape: [batch_size, seq_len, hidden_size]
@@ -48,7 +49,7 @@ def forward(
4849
ref_input=None, # Shape: [batch_size, seq_len, hidden_size]
4950
old_per_token_logps=None,
5051
):
51-
logits = (x @ self.lin.weight.t())
52+
logits = x @ self.lin.weight.t()
5253
if self.lin.bias is not None:
5354
logits = logits + self.lin.bias
5455
if self.temperature != 1.0:
@@ -81,9 +82,7 @@ def forward(
8182
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
8283
if self.beta != 0.0:
8384
# Compute KL divergence between model and reference model
84-
kl_div = (
85-
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
86-
)
85+
kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
8786
per_token_loss = per_token_loss + self.beta * kl_div
8887

8988
# Apply masking and normalize
@@ -171,9 +170,9 @@ def forward(
171170
"beta, epsilon_low, epsilon_high, temperature",
172171
[
173172
# Standard settings
174-
(0.1, 0.2, 0.2, 20.0), # set temperature to 20.0 for better numerical stability
173+
(0.1, 0.2, 0.2, 20.0), # set temperature to 20.0 for better numerical stability
175174
(0.0, 0.1, 0.1, 2.0),
176-
]
175+
],
177176
)
178177
@pytest.mark.parametrize("use_ref_model", [True, False])
179178
@pytest.mark.parametrize("old_per_token_logps", [True, False])
@@ -231,7 +230,9 @@ def test_correctness(
231230
V, H, device=device, dtype=dtype
232231
)
233232
if ref_bias:
234-
torch_lm_head_grpo.ref_lin.bias.data = liger_lm_head_grpo.ref_lin.bias.data = torch.randn(V, device=device, dtype=dtype)
233+
torch_lm_head_grpo.ref_lin.bias.data = liger_lm_head_grpo.ref_lin.bias.data = torch.randn(
234+
V, device=device, dtype=dtype
235+
)
235236

236237
# Create inputs with shape [B, T, H]
237238
_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
@@ -260,15 +261,25 @@ def test_correctness(
260261

261262
# Forward pass with reference model
262263
loss1, aux1 = torch_lm_head_grpo(
263-
input1, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
264+
input1,
265+
selected_token_ids,
266+
attention_mask,
267+
advantages,
268+
ref_input=ref_input,
269+
old_per_token_logps=old_per_token_logps,
264270
)
265271
loss2, aux2 = liger_lm_head_grpo(
266-
input2, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
272+
input2,
273+
selected_token_ids,
274+
attention_mask,
275+
advantages,
276+
ref_input=ref_input,
277+
old_per_token_logps=old_per_token_logps,
267278
)
268279

269280
# Check losses match
270-
assert loss1 != float('nan')
271-
assert loss2 != float('nan')
281+
assert loss1 != float("nan")
282+
assert loss2 != float("nan")
272283
assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
273284

274285
# Check metrics match
@@ -296,6 +307,7 @@ def test_correctness(
296307
rtol=rtol,
297308
)
298309

310+
299311
@pytest.mark.parametrize(
300312
"B, T, H, V",
301313
[
@@ -316,14 +328,29 @@ def test_correctness(
316328
"beta, epsilon_low, epsilon_high, temperature",
317329
[
318330
# Standard settings
319-
(0.1, 0.2, 0.2, 20.0), # set temperature to 20.0 for better numerical stability
331+
(0.1, 0.2, 0.2, 20.0), # set temperature to 20.0 for better numerical stability
320332
(0.0, 0.1, 0.1, 2.0),
321-
]
333+
],
322334
)
323335
@pytest.mark.parametrize("use_ref_model", [True, False])
324336
@pytest.mark.parametrize("old_per_token_logps", [True, False])
325337
def test_functional_correctness(
326-
B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, beta, epsilon_low, epsilon_high, temperature, use_ref_model, old_per_token_logps
338+
B,
339+
T,
340+
H,
341+
V,
342+
scalar,
343+
dtype,
344+
atol,
345+
rtol,
346+
bias,
347+
ref_bias,
348+
beta,
349+
epsilon_low,
350+
epsilon_high,
351+
temperature,
352+
use_ref_model,
353+
old_per_token_logps,
327354
):
328355
_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
329356
input1 = _input.detach().clone().requires_grad_(True)
@@ -334,7 +361,7 @@ def test_functional_correctness(
334361
weight2 = _weight.detach().clone().requires_grad_(True)
335362

336363
selected_token_ids = torch.randint(0, V, (B, T), device=device)
337-
364+
338365
attention_mask = torch.ones(B, T, device=device)
339366

340367
advantages = torch.rand(B, device=device, dtype=dtype)
@@ -348,7 +375,7 @@ def test_functional_correctness(
348375
bias2 = None
349376

350377
ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
351-
378+
352379
_ref_weight = torch.randn(V, H, device=device, dtype=dtype) * scalar
353380
ref_weight1 = _ref_weight.detach().clone().requires_grad_(True)
354381
ref_weight2 = _ref_weight.detach().clone().requires_grad_(True)
@@ -367,15 +394,47 @@ def test_functional_correctness(
367394
old_per_token_logps = None
368395

369396
loss1, aux1 = liger_fused_linear_grpo(
370-
input1, weight1, selected_token_ids, attention_mask, advantages, bias1, ref_input, ref_weight1, ref_bias1, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
397+
input1,
398+
weight1,
399+
selected_token_ids,
400+
attention_mask,
401+
advantages,
402+
bias1,
403+
ref_input,
404+
ref_weight1,
405+
ref_bias1,
406+
old_per_token_logps,
407+
beta,
408+
epsilon_low,
409+
epsilon_high,
410+
temperature,
411+
True,
412+
use_ref_model,
413+
1,
371414
)
372415

373416
loss2, aux2 = LigerFusedLinearGRPOFunction.apply(
374-
input2, weight2, selected_token_ids, attention_mask, advantages, bias2, ref_input, ref_weight2, ref_bias2, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
417+
input2,
418+
weight2,
419+
selected_token_ids,
420+
attention_mask,
421+
advantages,
422+
bias2,
423+
ref_input,
424+
ref_weight2,
425+
ref_bias2,
426+
old_per_token_logps,
427+
beta,
428+
epsilon_low,
429+
epsilon_high,
430+
temperature,
431+
True,
432+
use_ref_model,
433+
1,
375434
)
376435

377-
assert loss1 != float('nan')
378-
assert loss2 != float('nan')
436+
assert loss1 != float("nan")
437+
assert loss2 != float("nan")
379438
assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
380439

381440
# Check metrics match

0 commit comments

Comments
 (0)