Add chunked LM head for memory-efficient log-prob computation for AsyncGRPOTrainer #5349
Conversation
Add `_ChunkedLogProbFunction` to compute per-token log-probs and entropy without materializing full vocabulary logits, reducing peak memory usage. Include `chunk_lm_head` config parameter and integrate into AsyncGRPO trainer. Add comprehensive tests for forward and backward passes including bfloat16 support.
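To illustrate the core idea behind the chunked computation (not the actual `_ChunkedLogProbFunction`, which operates on tensors and also accumulates entropy), here is a minimal pure-Python sketch of a per-token log-prob computed with an online logsumexp over vocabulary chunks, so the full score vector never has to be normalized at once:

```python
import math

def chunked_logprob(scores, target, chunk_size=4):
    """Log-prob of `target` under softmax(scores), streaming the vocabulary
    in chunks with a running (online) logsumexp. Illustrative sketch only;
    the real implementation works on [N, V] logits chunks, not scalars."""
    running_max = float("-inf")
    running_sum = 0.0  # sum of exp(s - running_max) over chunks seen so far
    for start in range(0, len(scores), chunk_size):
        chunk = scores[start:start + chunk_size]
        new_max = max(running_max, max(chunk))
        # Rescale the running sum to the new maximum, then fold in the chunk.
        running_sum = running_sum * math.exp(running_max - new_max)
        running_sum += sum(math.exp(s - new_max) for s in chunk)
        running_max = new_max
    logsumexp = running_max + math.log(running_sum)
    return scores[target] - logsumexp
```

The running rescale is what makes the streaming numerically equivalent to a single logsumexp over the whole vocabulary.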
This allows selective computation of log probabilities and entropy only for completion tokens, avoiding expensive matmuls on prompt tokens. The mask is applied before the chunked forward computation to filter the flattened hidden states and targets.
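A minimal sketch of that pre-filtering step, assuming hidden states and targets have already been flattened to `[batch * seq_len]` (the function name and list-based representation here are illustrative, not the PR's API):

```python
def select_completion_positions(hidden_states, targets, completion_mask):
    """Keep only rows where completion_mask is truthy, so the expensive
    per-token LM-head work runs only on completion tokens. Hypothetical
    helper mirroring the mask-before-forward filtering described above."""
    kept_hidden, kept_targets = [], []
    for h, t, m in zip(hidden_states, targets, completion_mask):
        if m:
            kept_hidden.append(h)
            kept_targets.append(t)
    return kept_hidden, kept_targets
```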
Add validation to prevent using both `chunk_lm_head_size` and `use_liger_kernel` simultaneously, as both optimize the LM head forward pass. Update help text to document this incompatibility.
When using fp16 autocast, hidden states are cast to float16 but lm_head weights are not, causing a dtype mismatch in matrix multiplication. Cast w_chunk to match last_hidden.dtype.
```python
grad_hidden.add_(grad_logits @ w_chunk.float())
grad_weight[start:end].add_(grad_logits.t() @ hidden.float())
# ...
return grad_hidden.to(hidden.dtype), grad_weight.to(weight.dtype), None, None, None
```
Backward pass silently ignores entropy gradient contribution
Medium Severity
_ChunkedLogProbFunction.backward receives grad_entropy but never uses it — only grad_logprobs contributes to grad_hidden and grad_weight. This means backpropagating through the entropy output (e.g., for entropy regularization) silently produces zero gradients. The current trainer only uses entropy for logging inside torch.no_grad(), so training is unaffected today, but the autograd function is mathematically incomplete for its second output.
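For reference, the dropped contribution has a closed form. Writing $z$ for the logits, $p = \mathrm{softmax}(z)$, and $H = -\sum_i p_i \log p_i$ (my notation, not symbols from the PR), the entropy gradient with respect to each logit is

```latex
\frac{\partial H}{\partial z_j} = -\,p_j\left(\log p_j + H\right)
```

so a complete backward would fold `grad_entropy`-weighted terms of this form into the per-chunk logits gradient alongside the log-prob term.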
```python
C = end - start
w_chunk = weight[start:end]  # [C, H]
# ...
torch.mm(hidden, w_chunk.t(), out=mm_buf[:, :C])
```
Backward missing dtype cast causes mixed-precision failure
Medium Severity
The forward explicitly casts weight chunks to match hidden dtype via .to(last_hidden.dtype) (with a comment explaining fp16 autocast scenarios), but the backward omits this cast at w_chunk = weight[start:end]. When hidden and weight have different dtypes (e.g., fp16 autocast with bfloat16 weights), torch.mm(hidden, w_chunk.t(), ...) will raise a dtype mismatch RuntimeError.
Additional Locations (1)
That's wonderful! A few quick comments:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
        "the standard full-logits path. Incompatible with `use_liger_kernel` (both replace the LM head "
        "forward pass)."
    },
)
```
Default chunk_lm_head_size breaks use_liger_kernel users
Medium Severity
chunk_lm_head_size defaults to 8192 (non-None), so existing code using AsyncGRPOConfig(use_liger_kernel=True) will now hit the mutual-exclusion check and raise a ValueError. Users must explicitly add chunk_lm_head_size=None to preserve their previous configuration. This is a breaking default.
Additional Locations (1)
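A backward-compatible alternative is to default the new option to `None` and keep the mutual-exclusion check, so existing `use_liger_kernel=True` configs continue to work unchanged. A hedged sketch (the class name and `__post_init__` placement here are illustrative, not the PR's actual config class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class AsyncGRPOConfigSketch:
    """Hypothetical slice of the config: chunk_lm_head_size defaults to
    None (standard full-logits path), so enabling use_liger_kernel alone
    never trips the mutual-exclusion check."""
    chunk_lm_head_size: Optional[int] = None
    use_liger_kernel: bool = False

    def __post_init__(self):
        if self.chunk_lm_head_size is not None and self.use_liger_kernel:
            raise ValueError(
                "chunk_lm_head_size and use_liger_kernel are mutually "
                "exclusive: both replace the LM head forward pass."
            )
```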
The `chunk_lm_head` module is now a reusable utility that can be shared across trainers (currently AsyncGRPO). Moving it from `async_grpo/` to `experimental/` makes this clearer and simplifies imports. Also adds support for `logit_scale` parameter (used by Cohere2 models) and checks for `final_logit_softcapping` compatibility before patching.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 4 total unresolved issues (including 3 from previous reviews).
```python
grad_hidden.add_(grad_logits @ w_chunk.float())
grad_weight[start:end].add_(grad_logits.t() @ hidden.float())
# ...
return grad_hidden.to(hidden.dtype), grad_weight.to(weight.dtype), None, None, None, None
```
Backward pass silently ignores entropy gradient contribution
Medium Severity
The backward method of _ChunkedLogProbFunction receives grad_entropy as a parameter but never uses it — only grad_logprobs (via g) contributes to grad_hidden and grad_weight. This means any gradient flowing through the entropy output is silently dropped. Currently safe because the trainer only reads entropy inside torch.no_grad(), but the function advertises two differentiable outputs and a reviewer already suggested reusing it across DPO, RLOO, and GRPO trainers where entropy regularization in the loss is common.
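If entropy gradients are added later, the analytic form `dH/dz_j = -p_j * (log p_j + H)` is easy to verify against finite differences. A self-contained pure-Python check (illustrative helper names, not code from the PR):

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def entropy(z):
    p = softmax(z)
    return -sum(pi * math.log(pi) for pi in p)

def entropy_grad(z):
    """Analytic dH/dz_j = -p_j * (log p_j + H): the term a complete
    backward would fold into grad_logits when grad_entropy is nonzero."""
    p = softmax(z)
    h = entropy(z)
    return [-pj * (math.log(pj) + h) for pj in p]
```

Comparing `entropy_grad` to central finite differences of `entropy` catches sign or scaling mistakes before they end up inside a fused autograd kernel.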
What does this PR do?

- Adds `chunk_lm_head.py`, a custom function that computes per-token log-probs and entropy without materializing the full [N, V] logits tensor, using online logsumexp
- Adds a `chunk_lm_head_size` config parameter to `AsyncGRPOConfig` to enable chunked mode with a configurable chunk size
- Supports `completion_mask` to avoid expensive matmuls on non-completion positions
- Incompatible with `use_liger_kernel` (both replace the LM head forward pass)

Tests

- bfloat16 gradient accuracy
- `completion_mask` behavior

Results
Benchmark Results & Script

Benchmark Script

This script uses `AsyncGRPOTrainer` with a synthetic 8192-token sequence to profile memory usage across different chunk sizes.

COMPARISON TABLE
Without chunked loss: [benchmark screenshot]

After chunked loss: [benchmark screenshot]
Note

Medium Risk

Medium risk because it monkey-patches `model.forward` and changes the trainer's forward/metric computation path based on a new config flag, which could cause model-compatibility or numerical regression issues. Safeguards include mutual exclusion with `use_liger_kernel`, auto-disable for `final_logit_softcapping`, and extensive forward/backward tests.

Overview
Adds a new experimental `chunk_lm_head` implementation (`trl/experimental/chunk_lm_head.py`) that computes per-token `log_probs` and `entropy` by streaming the vocabulary in chunks (online logsumexp) and exposes `patch_chunked_lm_head()` to replace a CausalLM's forward pass.

Wires this into `AsyncGRPOTrainer` behind a new `AsyncGRPOConfig.chunk_lm_head_size` flag: when enabled, training uses the patched forward to get `log_probs`/`entropy` (optionally skipping prompt tokens via `completion_mask`) instead of building full logits; it errors if combined with `use_liger_kernel` and disables itself for models with `final_logit_softcapping`.

Adds comprehensive tests validating numerical parity for forward/backward (fp32 and bf16), correct `completion_mask` behavior, and parity vs multiple real tiny CausalLM models on CUDA.

Written by Cursor Bugbot for commit 773489a. This will update automatically on new commits. Configure here.