
Add chunked LM head for memory-efficient log-prob computation for AsyncGRPOTrainer#5349

Open
AmineDiro wants to merge 10 commits into `main` from `chunked-lm-head`

Conversation


AmineDiro (Member) commented Mar 23, 2026

What does this PR do?

  • Add `chunk_lm_head.py`, a custom function that computes per-token log-probs and entropy without materializing the full [N, V] logits tensor, using an online logsumexp
  • Add a `chunk_lm_head_size` config parameter to `AsyncGRPOConfig` to enable chunked mode with a configurable chunk size
  • Patch the model's forward pass to skip prompt tokens via `completion_mask`, avoiding expensive matmuls on non-completion positions
  • Mutually exclusive with `use_liger_kernel`, since both replace the LM head forward pass
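The chunked computation in `chunk_lm_head.py` is not reproduced here, but the core idea — streaming the vocabulary in chunks while maintaining a running logsumexp — can be sketched as follows. Names like `chunked_logprobs` and the function signature are illustrative, not the PR's actual API:

```python
import torch

def chunked_logprobs(hidden, weight, targets, chunk_size=4096):
    """Per-token log-probs without materializing the full [N, V] logits.

    hidden: [N, H] last hidden states, weight: [V, H] lm_head weight,
    targets: [N] token ids. The vocabulary is streamed in chunks while a
    running (online) logsumexp is maintained per row.
    """
    N, V = hidden.shape[0], weight.shape[0]
    running_lse = torch.full((N,), float("-inf"))
    target_logits = torch.empty(N)
    for start in range(0, V, chunk_size):
        end = min(start + chunk_size, V)
        logits = (hidden @ weight[start:end].t()).float()  # [N, C] only
        # online logsumexp: fold this chunk into the running total
        running_lse = torch.logaddexp(running_lse, logits.logsumexp(dim=-1))
        # record the logits of targets that fall inside this chunk
        in_chunk = (targets >= start) & (targets < end)
        target_logits[in_chunk] = logits[in_chunk, targets[in_chunk] - start]
    return target_logits - running_lse  # log p(target)
```

Only a [N, C] buffer exists at any time, so peak memory scales with the chunk size rather than the vocabulary size.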

Tests

  • Unit tests for forward/backward numerical correctness against reference implementation
  • Tests for bfloat16 gradient accuracy
  • Tests for completion_mask correctness (masked matches unmasked at completion positions, zero at prompt positions)
  • Backward pass tests with completion_mask

Results

Benchmark script

This script uses AsyncGRPOTrainer with a synthetic 8192-token sequence to profile memory usage across different chunk sizes.

"""
Memory profiling script for AsyncGRPOTrainer with a single 8192-token sequence.

Runs three configurations of chunk_lm_head (None, 4096, 8192) and prints
a comparison table of peak memory usage.

Outputs:
  - memory_snapshot_chunk_{none,4096,8192}.pickle  (load in https://pytorch.org/memory_viz)
  - Comparison table of peak allocated/reserved GB and wall time
"""

import gc
import itertools
import os
import pickle
import queue
import tempfile
import time

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.experimental.async_grpo.async_rollout_worker import RolloutSample


def dummy_reward_func(completions, **kwargs):
    return [float(hash(c[0]["content"]) % 100) / 100.0 for c in completions]


class _StubRolloutWorker:
    """Minimal rollout worker stub for testing the trainer in isolation."""

    def __init__(self, tokenizer, dataset, num_generations: int = 8, samples_per_weight_sync: int = 10):
        self.rollout_buffer = queue.Queue()
        self._samples_per_weight_sync = samples_per_weight_sync
        self._model_version = 0
        self._sample_iter = self._make_sample_iter(tokenizer, dataset, num_generations)

    def _make_sample_iter(self, tokenizer, dataset, num_generations):
        import numpy as np

        for row in itertools.cycle(dataset):
            completions = [
                [{"role": "assistant", "content": f"{row['completion'][0]['content']} {idx}"}]
                for idx in range(num_generations)
            ]
            prompt_completions = [row["prompt"] + completion for completion in completions]
            prompt_ids = tokenizer.apply_chat_template(
                row["prompt"], tokenize=True, add_generation_prompt=True, return_dict=False
            )
            prompt_completion_ids = tokenizer.apply_chat_template(
                prompt_completions, tokenize=True, add_generation_prompt=False, return_dict=False
            )
            rewards = np.array(dummy_reward_func(completions))
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
            for idx in range(num_generations):
                completion_ids = prompt_completion_ids[idx][len(prompt_ids) :]
                yield RolloutSample(
                    prompt=row["prompt"],
                    completion=completions[idx],
                    input_ids=prompt_ids + completion_ids,
                    completion_mask=[0] * len(prompt_ids) + [1] * len(completion_ids),
                    old_log_probs=[0.0] * len(prompt_ids) + [-0.5] * len(completion_ids),
                    advantage=float(advantages[idx]),
                    model_version=self._model_version,
                    metrics={"reward": float(rewards[idx]), "reward_std": float(rewards.std())},
                )

    def _fill_queue(self):
        for _ in range(self._samples_per_weight_sync):
            self.rollout_buffer.put(next(self._sample_iter))

    def start(self):
        self._fill_queue()

    def update_model_version(self, version):
        self._model_version = version
        self._fill_queue()

    def stop(self):
        pass

    def pause(self):
        pass

    def resume(self):
        pass

    def send_weights(self, iterator):
        pass


class _LongSeqStubRolloutWorker(_StubRolloutWorker):
    """Stub that emits a single synthetic 8192-token sequence per sample."""

    SEQ_LEN = 8192

    def _make_sample_iter(self, tokenizer, dataset, num_generations):
        vocab_size = tokenizer.vocab_size or 32000
        input_ids = [i % vocab_size for i in range(self.SEQ_LEN)]
        completion_mask = [0] * (self.SEQ_LEN // 2) + [1] * (self.SEQ_LEN // 2)
        old_log_probs = [0.0] * self.SEQ_LEN
        prompt = [{"role": "user", "content": "x"}]
        completion = [{"role": "assistant", "content": "y"}]
        for _ in itertools.cycle([None]):
            yield RolloutSample(
                prompt=prompt,
                completion=completion,
                input_ids=input_ids,
                completion_mask=completion_mask,
                old_log_probs=old_log_probs,
                advantage=1.0,
                model_version=self._model_version,
                metrics={"reward": 1.0, "reward_std": 0.0},
            )


def run_profile(chunk_lm_head_value, model_id, dataset, tokenizer):
    """Run a single profiling pass with the given chunk_lm_head setting.

    Returns a dict with label, peak_alloc_gb, peak_reserved_gb, elapsed_ms.
    """
    label = str(chunk_lm_head_value) if chunk_lm_head_value is not None else "none"
    print(f"\n{'=' * 60}")
    print(f"Running with chunk_lm_head = {chunk_lm_head_value}")
    print(f"{'=' * 60}")

    tmp_dir = tempfile.mkdtemp()
    config_kwargs = dict(
        output_dir=tmp_dir,
        max_steps=1,
        per_device_train_batch_size=1,
        num_generations=1,
        max_completion_length=8192,
        report_to="none",
        learning_rate=1e-6,
        logging_steps=1,
    )
    if chunk_lm_head_value is not None:
        config_kwargs["chunk_lm_head_size"] = chunk_lm_head_value

    args = AsyncGRPOConfig(**config_kwargs)

    trainer = AsyncGRPOTrainer(
        model=model_id,
        reward_funcs=lambda completions, **kw: [1.0] * len(completions),
        args=args,
        train_dataset=dataset,
        rollout_worker=_LongSeqStubRolloutWorker(tokenizer, dataset, num_generations=1),
    )

    torch.cuda.memory._record_memory_history(max_entries=100000)
    torch.cuda.reset_peak_memory_stats()

    t0 = time.perf_counter()
    trainer.train()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    peak_alloc_gb = torch.cuda.max_memory_allocated() / 1024**3
    peak_reserved_gb = torch.cuda.max_memory_reserved() / 1024**3

    # Save memory snapshot
    snapshot = torch.cuda.memory._snapshot()
    snapshot_path = f"memory_snapshot_chunk_{label}.pickle"
    with open(snapshot_path, "wb") as f:
        pickle.dump(snapshot, f, protocol=4)
    torch.cuda.memory._record_memory_history(enabled=None)
    print(f"Saved: {os.getcwd()}/{snapshot_path}")
    print(
        f"Peak allocated: {peak_alloc_gb:.2f} GB | Peak reserved: {peak_reserved_gb:.2f} GB | Time: {elapsed * 1000:.1f} ms"
    )

    # Cleanup to free GPU memory before next run
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    return {
        "label": chunk_lm_head_value if chunk_lm_head_value is not None else "None",
        "peak_alloc_gb": peak_alloc_gb,
        "peak_reserved_gb": peak_reserved_gb,
        "elapsed_ms": elapsed * 1000,
    }


def main():
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    configs = [None, 4096, 8192]
    results = []

    for chunk_val in configs:
        result = run_profile(chunk_val, model_id, dataset, tokenizer)
        results.append(result)

    # Print comparison table
    baseline_alloc = results[0]["peak_alloc_gb"]
    baseline_reserved = results[0]["peak_reserved_gb"]

    print(f"\n{'=' * 95}")
    print("COMPARISON TABLE")
    print(f"{'=' * 95}")
    print(
        f"{'chunk_lm_head':>15} | {'Peak Alloc (GB)':>16} | {'Alloc Reduction':>16} | {'Peak Reserved (GB)':>18} | {'Wall Time (ms)':>15}"
    )
    print(f"{'-' * 15}-+-{'-' * 16}-+-{'-' * 16}-+-{'-' * 18}-+-{'-' * 15}")
    for r in results:
        reduction = f"{baseline_alloc / r['peak_alloc_gb']:.2f}x" if r["peak_alloc_gb"] > 0 else "N/A"
        print(
            f"{str(r['label']):>15} | {r['peak_alloc_gb']:>16.2f} | {reduction:>16} | {r['peak_reserved_gb']:>18.2f} | {r['elapsed_ms']:>15.1f}"
        )
    print(f"{'=' * 95}")


if __name__ == "__main__":
    main()

COMPARISON TABLE

| chunk_lm_head | Peak Alloc (GB) | Alloc Reduction | Peak Reserved (GB) | Wall Time (ms) |
|---------------|-----------------|-----------------|--------------------|----------------|
| None          | 18.55           | 1.00x           | 20.89              | 808.7          |
| 4096          | 0.42            | 44.32x          | 4.67               | 459.0          |
| 8192          | 0.76            | 24.34x          | 4.67               | 393.0          |

Without chunked loss: (memory snapshot screenshot)

After chunked loss: (memory snapshot screenshot)


Note

Medium Risk
Medium risk because it monkey-patches model.forward and changes the trainer’s forward/metric computation path based on a new config flag, which could cause model-compatibility or numerical/regression issues. Safeguards include mutual-exclusion with use_liger_kernel, auto-disable for final_logit_softcapping, and extensive forward/backward tests.

Overview
Adds a new experimental chunk_lm_head implementation (trl/experimental/chunk_lm_head.py) that computes per-token log_probs and entropy by streaming the vocabulary in chunks (online logsumexp) and exposes patch_chunked_lm_head() to replace a CausalLM’s forward pass.

Wires this into AsyncGRPOTrainer behind a new AsyncGRPOConfig.chunk_lm_head_size flag: when enabled, training uses the patched forward to get log_probs/entropy (optionally skipping prompt tokens via completion_mask) instead of building full logits; it errors if combined with use_liger_kernel and disables itself for models with final_logit_softcapping.

Adds comprehensive tests validating numerical parity for forward/backward (fp32 and bf16), correct completion_mask behavior, and parity vs multiple real tiny CausalLM models on CUDA.

Written by Cursor Bugbot for commit 773489a.

Commits:

  • Add `_ChunkedLogProbFunction` to compute per-token log-probs and entropy without materializing full vocabulary logits, reducing peak memory usage. Include the `chunk_lm_head` config parameter and integrate it into the AsyncGRPO trainer. Add comprehensive tests for forward and backward passes, including bfloat16 support.
  • Compute log probabilities and entropy only for completion tokens, avoiding expensive matmuls on prompt tokens. The mask is applied before the chunked forward computation to filter the flattened hidden states and targets.
  • Add validation to prevent using both `chunk_lm_head_size` and `use_liger_kernel` simultaneously, as both optimize the LM head forward pass. Update help text to document this incompatibility.
  • When using fp16 autocast, hidden states are cast to float16 but lm_head weights are not, causing a dtype mismatch in matrix multiplication. Cast `w_chunk` to match `last_hidden.dtype`.
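The completion-mask filtering mentioned above can be sketched like this — a minimal illustration with assumed names; the PR's actual helper may differ:

```python
import torch

def select_completion_positions(hidden, targets, completion_mask):
    """Keep only completion-token rows (sketch with assumed names).

    hidden: [N, H] flattened hidden states, targets: [N] token ids,
    completion_mask: [N] with 1 at completion positions and 0 at prompt
    positions. Filtering before the chunked forward means the LM-head
    matmuls only run on positions that contribute to the loss.
    """
    keep = completion_mask.bool()
    return hidden[keep], targets[keep]
```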
AmineDiro marked this pull request as ready for review March 23, 2026 09:06
grad_hidden.add_(grad_logits @ w_chunk.float())
grad_weight[start:end].add_(grad_logits.t() @ hidden.float())

return grad_hidden.to(hidden.dtype), grad_weight.to(weight.dtype), None, None, None

Backward pass silently ignores entropy gradient contribution

Medium Severity

_ChunkedLogProbFunction.backward receives grad_entropy but never uses it — only grad_logprobs contributes to grad_hidden and grad_weight. This means backpropagating through the entropy output (e.g., for entropy regularization) silently produces zero gradients. The current trainer only uses entropy for logging inside torch.no_grad(), so training is unaffected today, but the autograd function is mathematically incomplete for its second output.
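For reference, the entropy gradient a complete backward would need can be checked numerically. This is a standalone sketch, not the PR's code: for logits z with p = softmax(z) and per-row entropy H, the gradient is dH/dz_j = -p_j (log p_j + H).

```python
import torch

# Verify the closed-form entropy gradient against autograd.
z = torch.randn(4, 7, dtype=torch.float64, requires_grad=True)
logp = torch.log_softmax(z, dim=-1)
H = -(logp.exp() * logp).sum(dim=-1)  # per-row entropy
H.sum().backward()

p = logp.exp().detach()
expected = -p * (logp.detach() + H.detach().unsqueeze(-1))
assert torch.allclose(z.grad, expected, atol=1e-8)
```

A chunked backward supporting entropy regularization would accumulate this term, scaled by `grad_entropy`, alongside the log-prob gradient.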


C = end - start
w_chunk = weight[start:end] # [C, H]

torch.mm(hidden, w_chunk.t(), out=mm_buf[:, :C])

Backward missing dtype cast causes mixed-precision failure

Medium Severity

The forward explicitly casts weight chunks to match hidden dtype via .to(last_hidden.dtype) (with a comment explaining fp16 autocast scenarios), but the backward omits this cast at w_chunk = weight[start:end]. When hidden and weight have different dtypes (e.g., fp16 autocast with bfloat16 weights), torch.mm(hidden, w_chunk.t(), ...) will raise a dtype mismatch RuntimeError.
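A minimal repro of the reported mismatch (shapes and names assumed), showing that casting the weight chunk to the hidden dtype — as the forward already does — avoids the error:

```python
import torch

# Mixed dtypes in torch.mm raise a RuntimeError, so the backward needs
# the same cast the forward applies. Shapes here are illustrative.
hidden = torch.randn(3, 8)                          # float32 activations
weight = torch.randn(16, 8, dtype=torch.bfloat16)   # bf16 lm_head weight
w_chunk = weight[0:4]

try:
    torch.mm(hidden, w_chunk.t())                   # dtype mismatch
    raised = False
except RuntimeError:
    raised = True

out = torch.mm(hidden, w_chunk.to(hidden.dtype).t())  # cast fixes it
```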


AmineDiro (Member, Author):

@kashif

Benchmarked against Liger's `LigerFusedLinearGRPOLoss`:

  • Liger materializes [1, S, V] logits per batch chunk inside chunk_forward. Even though it's transient (freed by `grad_and_value`), the peak allocation scales as O(S * V). With V = 152064 and S = 32768, that's 32768 * 152064 * 4 bytes ≈ 18.6 GB for the float32 log-probs tensor alone.
  • Liger is dramatically faster at every sequence length. At S = 32k: 167 ms vs 1654 ms (~10x).
  • The fundamental difference: chunking along V trades compute efficiency for memory efficiency. Liger's approach is compute-optimal (large, torch-compiled matmuls) but its memory scales with sequence length. Ours is memory-optimal but carries overhead from small, less efficient matmuls.
(plots: median time vs. sequence length, peak memory vs. sequence length)

qgallouedec (Member):

That's wonderful!

A few comments:

  • You can move this function to trl.trainer.utils, as it should be usable in the future for DPO, Reward, RLOO, and GRPO.
  • I recommend not adding a new argument yet: use a sensible default value, and we'll add an argument later if the community requests it.
  • Add more tests: vary the models and attention implementations; check the existing ones as examples.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

"the standard full-logits path. Incompatible with `use_liger_kernel` (both replace the LM head "
"forward pass)."
},
)

Default chunk_lm_head_size breaks use_liger_kernel users

Medium Severity

chunk_lm_head_size defaults to 8192 (non-None), so existing code using AsyncGRPOConfig(use_liger_kernel=True) will now hit the mutual-exclusion check and raise a ValueError. Users must explicitly add chunk_lm_head_size=None to preserve their previous configuration. This is a breaking default.


The `chunk_lm_head` module is now a reusable utility that can be shared
across trainers (currently AsyncGRPO). Moving it from `async_grpo/` to
`experimental/` makes this clearer and simplifies imports.

Also adds support for `logit_scale` parameter (used by Cohere2 models)
and checks for `final_logit_softcapping` compatibility before patching.
cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 4 total unresolved issues (including 3 from previous reviews).


grad_hidden.add_(grad_logits @ w_chunk.float())
grad_weight[start:end].add_(grad_logits.t() @ hidden.float())

return grad_hidden.to(hidden.dtype), grad_weight.to(weight.dtype), None, None, None, None

Backward pass silently ignores entropy gradient contribution

Medium Severity

The backward method of _ChunkedLogProbFunction receives grad_entropy as a parameter but never uses it — only grad_logprobs (via g) contributes to grad_hidden and grad_weight. This means any gradient flowing through the entropy output is silently dropped. Currently safe because the trainer only reads entropy inside torch.no_grad(), but the function advertises two differentiable outputs and a reviewer already suggested reusing it across DPO, RLOO, and GRPO trainers where entropy regularization in the loss is common.

