
Commit 4980c3f

kashif and lancerts authored
[cross-entropy-loss] Added support for DFT flag (#860)
## Summary

Added support for a flag that turns on the DFT cross-entropy loss from the paper https://arxiv.org/abs/2508.05629

- Hardware Type: cuda
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

Co-authored-by: Shao Tang <[email protected]>
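Concretely, with the new `use_token_scaling` flag enabled, each token's cross-entropy loss is multiplied by the model's detached predicted probability of that token's true class. In the notation below (ours, not the diff's), for token $i$ with logits $z_i$ and label $y_i$:

$$\ell_i = \operatorname{sg}\big(p_{i,y_i}\big)\cdot \mathrm{CE}(z_i, y_i), \qquad \frac{\partial \ell_i}{\partial z_i} = \operatorname{sg}\big(p_{i,y_i}\big)\cdot \frac{\partial\, \mathrm{CE}(z_i, y_i)}{\partial z_i},$$

where $p_{i,y_i} = \operatorname{softmax}(z_i)_{y_i}$ and $\operatorname{sg}(\cdot)$ is stop-gradient (`detach()`). This matches the scaling applied to both the loss and the logits gradient in the kernel changes below.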
1 parent 77a4c1a commit 4980c3f

4 files changed, +272 −1 lines changed


src/liger_kernel/ops/fused_linear_cross_entropy.py

Lines changed: 41 additions & 1 deletion
```diff
@@ -26,6 +26,7 @@ def fused_linear_cross_entropy_forward(
     softcap=None,
     return_z_loss=False,
     accum_dtype=None,
+    use_token_scaling=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     device = _input.device
@@ -89,6 +90,23 @@ def fused_linear_cross_entropy_forward(

         n_rows = logits_chunk.shape[0]

+        # Compute predicted probabilities for token scaling if needed
+        if use_token_scaling:
+            # Compute softmax probabilities for scaling
+            # We need to compute this before the cross entropy kernel modifies logits_chunk
+            logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+            if softcap is not None:
+                logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+            # Compute softmax to get predicted probabilities
+            probs = torch.softmax(logits_for_softmax, dim=-1)
+
+            # Get the predicted probability for each target token
+            pred_probs = torch.gather(probs, -1, target_chunk.unsqueeze(-1)).squeeze(-1)
+
+            # Store the scaling factors
+            scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
@@ -123,11 +141,23 @@ def fused_linear_cross_entropy_forward(
             num_warps=32 if not is_hip() else 16,
         )

+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V

+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
+
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

         if grad_weight is not None:
@@ -136,7 +166,7 @@ def fused_linear_cross_entropy_forward(
         if bias is not None:
             torch.add(
                 input=grad_bias,
-                other=logits_chunk.sum(dim=0),
+                other=grad_logits_chunk.sum(dim=0),
                 out=grad_bias,
                 alpha=1.0,
             )
@@ -146,6 +176,10 @@ def fused_linear_cross_entropy_forward(
     # loss = loss_1d
    # z_loss = z_loss_1d if return_z_loss else None

+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
@@ -221,6 +255,7 @@ def forward(
         softcap=None,
         return_z_loss: bool = False,
         accum_dtype=None,
+        use_token_scaling: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -241,6 +276,9 @@ def forward(
         reduction: reduction to apply
         accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
            Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+        use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+            When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+            Default: False.
         """

         loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
@@ -256,6 +294,7 @@ def forward(
             softcap=softcap,
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
+            use_token_scaling=use_token_scaling,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -288,4 +327,5 @@ def backward(ctx, grad_output, grad_output2):
             None,
             None,
             None,
+            None,  # use_token_scaling
         )
```
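A note on the bias-gradient line changed above: because the scaling factor is detached, scaling the per-token loss is equivalent to scaling the cross-entropy gradient by the same factor, which is why the kernel multiplies `grad_logits_chunk` by `scaling_factors` and now sums the scaled gradient into `grad_bias`. A minimal eager-mode sanity check of that identity (not part of the diff, plain PyTorch on any device):

```python
import torch

torch.manual_seed(0)
N, V = 8, 16
logits = torch.randn(N, V, requires_grad=True)
target = torch.randint(0, V, (N,))

# Detached predicted probability of the true class, mirroring `scaling_factors` above
pred_probs = torch.softmax(logits.detach(), dim=-1).gather(1, target.unsqueeze(-1)).squeeze(-1)

# Gradient of the scaled loss ...
scaled_loss = (torch.nn.functional.cross_entropy(logits, target, reduction="none") * pred_probs).sum()
(grad_scaled,) = torch.autograd.grad(scaled_loss, logits)

# ... equals the per-token scaled gradient of the unscaled loss
plain_loss = torch.nn.functional.cross_entropy(logits, target, reduction="none").sum()
(grad_plain,) = torch.autograd.grad(plain_loss, logits)
assert torch.allclose(grad_scaled, grad_plain * pred_probs.unsqueeze(-1), atol=1e-6)
```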

src/liger_kernel/transformers/functional.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -65,6 +65,7 @@ def liger_fused_linear_cross_entropy(
     softcap: Optional[float] = None,
     return_z_loss: bool = False,
     accum_dtype=None,
+    use_token_scaling: bool = False,
 ):
     loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
         input,
@@ -79,6 +80,7 @@ def liger_fused_linear_cross_entropy(
         softcap,
         return_z_loss,
         accum_dtype,
+        use_token_scaling,
     )
     if not return_z_loss:
         return loss
```
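For context, a minimal call through this functional wrapper with the new flag could look like the sketch below. The device and shapes are illustrative, and the import path is inferred from the file path above; the argument names mirror the tests added in this commit.

```python
import torch
from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

device = "cuda"  # illustrative; the PR was tested on CUDA hardware
B, T, H, V = 2, 4, 8, 16

_input = torch.randn(B * T, H, device=device, requires_grad=True)
weight = torch.randn(V, H, device=device)
bias = torch.randn(V, device=device)
target = torch.randint(0, V, (B * T,), device=device)

# Per-token DFT-style losses: each token's CE scaled by its detached predicted probability
loss = liger_fused_linear_cross_entropy(
    input=_input,
    weight=weight,
    target=target,
    bias=bias,
    reduction="none",
    use_token_scaling=True,
)
loss.sum().backward()
```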

src/liger_kernel/transformers/fused_linear_cross_entropy.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@ def __init__(
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
         accum_dtype: Optional[torch.dtype] = None,
+        use_token_scaling: bool = False,
     ):
         super().__init__()
         assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -34,6 +35,7 @@ def __init__(
         self.softcap = softcap
         self.return_z_loss = return_z_loss
         self.accum_dtype = accum_dtype
+        self.use_token_scaling = use_token_scaling

     def forward(self, lin_weight, _input, target, bias=None):
         loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
@@ -49,6 +51,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.softcap,
             self.return_z_loss,
             self.accum_dtype,
+            self.use_token_scaling,
         )
         if not self.return_z_loss:
             return loss
```
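Similarly, a short sketch of the module API with the new switch; device and shapes are illustrative, and the positional argument order `(lin_weight, _input, target, bias)` follows `forward` above, as in the new tests.

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

device = "cuda"  # illustrative
B, T, H, V = 2, 4, 8, 16

flce = LigerFusedLinearCrossEntropyLoss(reduction="sum", use_token_scaling=True)

_input = torch.randn(B * T, H, device=device, requires_grad=True)
weight = torch.randn(V, H, device=device)
bias = torch.randn(V, device=device)
target = torch.randint(0, V, (B * T,), device=device)

# Note the argument order: (lin_weight, _input, target, bias)
loss = flce(weight, _input, target, bias)
loss.backward()
```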

test/transformers/test_fused_linear_cross_entropy.py

Lines changed: 226 additions & 0 deletions
```diff
@@ -352,3 +352,229 @@ def test_amp(B, T, H, V, bias, cast_dtype, accum_dtype, atol, rtol):
         atol=atol,
         rtol=rtol,
     )
+
+
+def test_correctness_token_scaling():
+    """Test that token scaling produces the correct loss values and gradients."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype, requires_grad=True)
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test using functional API with token scaling
+    loss_scaled = liger_fused_linear_cross_entropy(
+        input=_input,
+        weight=weight,
+        target=target,
+        bias=bias,
+        ignore_index=-100,
+        reduction="none",  # Use "none" to get per-token losses
+        use_token_scaling=True,
+    )
+
+    # Compare with manual implementation
+    # Compute logits
+    logits = _input @ weight.t()
+    if bias is not None:
+        logits = logits + bias
+
+    # Compute standard cross entropy loss per token
+    ce_loss = torch.nn.functional.cross_entropy(logits, target, ignore_index=-100, reduction="none")
+
+    # Compute predicted probabilities for target tokens
+    pred_probs = torch.softmax(logits, dim=-1).gather(1, target.unsqueeze(-1)).squeeze(-1).detach()
+
+    # Scale by predicted probabilities
+    expected_loss = ce_loss * pred_probs
+
+    # Check that losses are close
+    assert torch.allclose(loss_scaled, expected_loss, atol=1e-4, rtol=1e-4)
+
+    # Test gradients
+    loss_scaled.sum().backward(retain_graph=True)
+    grad_scaled = _input.grad.clone()
+    _input.grad.zero_()
+
+    expected_loss.sum().backward(retain_graph=True)
+    grad_expected = _input.grad.clone()
+    _input.grad.zero_()
+
+    # Check that gradients are close
+    assert torch.allclose(grad_scaled, grad_expected, atol=1e-4, rtol=1e-4)
+
+
+def test_correctness_token_scaling_consistency():
+    """Test that token scaling is consistent between functional and module APIs."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype, requires_grad=True)
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test functional API
+    loss_functional = liger_fused_linear_cross_entropy(
+        input=_input,
+        weight=weight,
+        target=target,
+        bias=bias,
+        ignore_index=-100,
+        reduction="sum",
+        use_token_scaling=True,
+    )
+
+    # Test module API
+    ce_loss_module = LigerFusedLinearCrossEntropyLoss(
+        ignore_index=-100,
+        reduction="sum",
+        use_token_scaling=True,
+    )
+
+    loss_module = ce_loss_module(weight, _input, target, bias)
+
+    # Check that losses are identical
+    assert torch.allclose(loss_functional, loss_module, atol=1e-6, rtol=1e-6)
+
+    # Test gradients
+    loss_functional.backward(retain_graph=True)
+    grad_functional = _input.grad.clone()
+    _input.grad.zero_()
+
+    loss_module.backward(retain_graph=True)
+    grad_module = _input.grad.clone()
+    _input.grad.zero_()
+
+    # Check that gradients are identical
+    assert torch.allclose(grad_functional, grad_module, atol=1e-6, rtol=1e-6)
+
+
+def test_correctness_token_scaling_functional():
+    """Test token scaling using the functional API."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype)
+    x1 = _input.detach().clone().requires_grad_(True)
+    x2 = _input.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test using functional API with token scaling
+    y1 = liger_fused_linear_cross_entropy(
+        input=x1,
+        weight=weight,
+        target=target,
+        bias=bias,
+        ignore_index=-100,
+        lse_square_scale=0.0,
+        label_smoothing=0.0,
+        reduction="sum",  # Use sum for easier verification
+        softcap=None,
+        return_z_loss=False,
+        accum_dtype=None,
+        use_token_scaling=True,
+    )
+
+    # Compare with manual implementation
+    # Compute logits
+    logits = x2 @ weight.t()
+    if bias is not None:
+        logits = logits + bias
+
+    # Compute softmax probabilities
+    probs = torch.softmax(logits.detach(), dim=-1)  # Detach to avoid gradient flow
+
+    # Get predicted probabilities for target tokens
+    pred_probs = torch.gather(probs, -1, target.unsqueeze(-1)).squeeze(-1)
+
+    # Compute standard cross entropy loss
+    ce_loss = torch.nn.functional.cross_entropy(logits, target, ignore_index=-100, reduction="none")
+
+    # Scale by predicted probabilities
+    scaled_loss = ce_loss * pred_probs
+
+    # Sum over all tokens
+    y2 = scaled_loss.sum()
+
+    # Check that losses are close
+    assert torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)
+
+    # Test gradients
+    y1.backward()
+    y2.backward()
+
+    # Check that gradients are close
+    assert torch.allclose(x1.grad, x2.grad, atol=1e-5, rtol=1e-5)
+
+
+def test_correctness_token_scaling_module():
+    """Test token scaling using the module API."""
+    B, T, H, V = 2, 4, 8, 16
+    dtype = torch.float32
+
+    # Create inputs
+    _input = torch.randn(B * T, H, device=device, dtype=dtype)
+    x1 = _input.detach().clone().requires_grad_(True)
+    x2 = _input.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Create module with token scaling
+    ce_loss = LigerFusedLinearCrossEntropyLoss(
+        ignore_index=-100,
+        reduction="sum",
+        use_token_scaling=True,
+    )
+
+    # Create weights
+    weight = torch.randn(V, H, device=device, dtype=dtype)
+    bias = torch.randn(V, device=device, dtype=dtype)
+
+    # Test using module API with token scaling
+    y1 = ce_loss(weight, x1, target, bias)
+
+    # Compare with manual implementation
+    # Compute logits
+    logits = x2 @ weight.t()
+    if bias is not None:
+        logits = logits + bias
+
+    # Compute softmax probabilities
+    probs = torch.softmax(logits.detach(), dim=-1)  # Detach to avoid gradient flow
+
+    # Get predicted probabilities for target tokens
+    pred_probs = torch.gather(probs, -1, target.unsqueeze(-1)).squeeze(-1)
+
+    # Compute standard cross entropy loss
+    ce_loss_manual = torch.nn.functional.cross_entropy(logits, target, ignore_index=-100, reduction="none")
+
+    # Scale by predicted probabilities
+    scaled_loss = ce_loss_manual * pred_probs
+
+    # Sum over all tokens
+    y2 = scaled_loss.sum()
+
+    # Check that losses are close
+    assert torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)
+
+    # Test gradients
+    y1.backward()
+    y2.backward()
+
+    # Check that gradients are close
+    assert torch.allclose(x1.grad, x2.grad, atol=1e-5, rtol=1e-5)
```
