Skip to content

Commit fa24166

Browse files
authored
[Cross-entropy] get valid predicted probabilities (#864)
## Summary

Forgot to check for ignored tokens when calculating probabilities in #860

- Hardware Type: cuda
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
1 parent a089cd5 commit fa24166

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

src/liger_kernel/ops/fused_linear_cross_entropy.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,21 @@ def fused_linear_cross_entropy_forward(
101101
# Compute softmax to get predicted probabilities
102102
probs = torch.softmax(logits_for_softmax, dim=-1)
103103

104-
# Get the predicted probability for each target token
105-
pred_probs = torch.gather(probs, -1, target_chunk.unsqueeze(-1)).squeeze(-1)
104+
# Get predicted probabilities for token scaling, handling ignored targets
105+
valid_target_mask = target_chunk != ignore_index
106+
valid_targets = target_chunk[valid_target_mask]
107+
108+
if len(valid_targets) > 0:
109+
# Gather probabilities only for valid targets
110+
valid_probs = probs[valid_target_mask]
111+
pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
112+
113+
# Create full tensor with zeros for ignored targets
114+
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
115+
pred_probs[valid_target_mask] = pred_probs_valid
116+
else:
117+
# All targets are ignored
118+
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
106119

107120
# Store the scaling factors
108121
scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow

src/liger_kernel/transformers/model/loss_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def fixed_fused_linear_cross_entropy(
2525
ignore_index=ignore_index,
2626
softcap=final_logit_softcapping,
2727
accum_dtype=accum_dtype,
28+
**kwargs,
2829
)
2930
if reduction == "sum":
3031
loss = loss / num_items_in_batch

test/transformers/test_fused_linear_cross_entropy.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,41 @@ def test_correctness_token_scaling_module():
578578

579579
# Check that gradients are close
580580
assert torch.allclose(x1.grad, x2.grad, atol=1e-5, rtol=1e-5)
581+
582+
583+
def test_token_scaling_with_ignore_index():
584+
"""Test token scaling when some targets have ignore_index values."""
585+
B, T, H, V = 2, 4, 8, 1000
586+
dtype = torch.float32
587+
588+
# Create inputs
589+
_input = torch.randn(B * T, H, device=device, dtype=dtype, requires_grad=True)
590+
591+
# Create targets with some ignore_index values (-100)
592+
target = torch.tensor([0, 100, -100, 500, -100, 999], device=device, dtype=torch.long)
593+
_input = torch.randn(6, H, device=device, dtype=dtype, requires_grad=True) # Adjust input size
594+
595+
# Create weights
596+
weight = torch.randn(V, H, device=device, dtype=dtype)
597+
bias = torch.randn(V, device=device, dtype=dtype)
598+
599+
# Test using functional API with token scaling
600+
loss_scaled = liger_fused_linear_cross_entropy(
601+
input=_input,
602+
weight=weight,
603+
target=target,
604+
bias=bias,
605+
ignore_index=-100,
606+
reduction="sum",
607+
use_token_scaling=True,
608+
)
609+
610+
# This should not raise any CUDA errors
611+
assert loss_scaled.numel() == 1 # Should return a scalar for sum reduction
612+
assert not torch.isnan(loss_scaled) # Should not be NaN
613+
assert not torch.isinf(loss_scaled) # Should not be infinite
614+
615+
# Test gradients
616+
loss_scaled.backward()
617+
assert _input.grad is not None
618+
assert not torch.isnan(_input.grad).any() # Gradients should not be NaN

0 commit comments

Comments (0)