
[GKD] Buffer Implementation for Distillation Trainer#5137

Open
cmpatino wants to merge 28 commits into huggingface:main from cmpatino:kd-buffering

Conversation

Collaborator

@cmpatino cmpatino commented Feb 20, 2026

Implement Buffer for Distillation Trainer (GOLDTrainer)

Implement generation buffering and multi-generation support for GOLDTrainer

Add a prompt-level generation buffer that decouples generation from the
optimization steps. We adopt a buffer similar to GRPO's to generate the rollouts for all mini-batches within an optimization step, leveraging parallel inference engines. Each worker therefore handles a buffer of per_device_train_batch_size * gradient_accumulation_steps samples.

Buffer Details

We allow multiple rollouts per prompt, following Thinking Machines' Tinker example. The number of rollouts per prompt is controlled by the num_generations parameter. To keep the effective batch size constant, we introduce the generation_batch_size parameter, which controls how many unique prompts we pass to the inference engine. We enforce generation_batch_size = per_device_train_batch_size * gradient_accumulation_steps // num_generations so that the effective batch size is invariant across setups.
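The partitioning constraint above can be sketched as a small helper. The function name and signature are illustrative, not the actual GOLDConfig validation code:

```python
def resolve_generation_batch_size(
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_generations: int,
) -> int:
    """Return how many unique prompts to send to the inference engine so the
    effective batch size stays constant across num_generations settings."""
    optimizer_window = per_device_train_batch_size * gradient_accumulation_steps
    if optimizer_window % num_generations != 0:
        # num_generations must exactly partition the local optimizer-step batch
        raise ValueError(
            f"num_generations={num_generations} must evenly divide the "
            f"optimizer-window batch of {optimizer_window} samples."
        )
    return optimizer_window // num_generations
```

For example, with a per-device batch size of 4, 8 accumulation steps, and 4 generations per prompt, each worker sends 8 unique prompts to the inference engine while still training on 32 samples per optimizer step.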

Benchmarks

We can replicate Thinking Machines' results using both non-Liger and Liger losses, achieving a roughly 3x speedup on a setup with 8 training nodes in colocate mode.

comparison_tinker_liger (benchmark comparison plot)

| Phase    | Tinker (s) | TRL (s) |
| -------- | ---------- | ------- |
| Sampling | 329.83     | 130     |
| Loss     | 37.96      | -       |
| Training | 98.69      | 38      |
| Total    | 492.28     | 173     |

Before submitting

  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


Note

Medium Risk
Changes core training-loop mechanics (dataloader/sampling, buffering, and generation paths), which can affect correctness and throughput across distributed/vLLM/ZeRO setups. Config validation reduces misconfiguration risk but new batch-partitioning constraints and buffering edge cases could surface in varied training pipelines.

Overview
GOLDTrainer now buffers rollouts across gradient-accumulation windows: the train dataloader samples one optimizer-window batch via RepeatSampler, _prepare_inputs slices it per accumulation step, and on-policy slices trigger a single batched generation pass (vLLM server/colocate or model.generate) whose results are written back into the buffered inputs for training.

The config surface is updated with num_generations and generation_batch_size (with validation that they exactly partition the local optimizer-step batch), plus improved handling of student model revisions (defaulting student_model_revision from model_revision and erroring on conflicts); docs and examples are updated accordingly. Misc fixes include vLLM prompt de-duplication for multi-gen, standardized label/attention building, Liger JSD tweaks (weights + reshaping + ZeRO-3 gather context), and preserving messages when using Liger.
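The per-accumulation-step slicing described above can be sketched as follows. This is a minimal plain-Python illustration of the contract, not the actual `_prepare_inputs` code; the name `slice_buffer` is hypothetical:

```python
def slice_buffer(buffer: list, per_device_train_batch_size: int, step: int) -> list:
    """Return the mini-batch for accumulation step `step` from a buffer that
    holds per_device_train_batch_size * gradient_accumulation_steps samples."""
    start = step * per_device_train_batch_size
    return buffer[start : start + per_device_train_batch_size]
```

Each accumulation step reads a disjoint, deterministic slice, so the whole optimizer window is covered exactly once without re-sampling prompts.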

Written by Cursor Bugbot for commit f89e77f.

@cmpatino cmpatino requested a review from qgallouedec March 3, 2026 21:24
@cmpatino cmpatino marked this pull request as ready for review March 3, 2026 21:24
Contributor

Copilot AI left a comment


Pull request overview

Implements prompt-level rollout buffering and multi-generation support for GOLDTrainer, decoupling generation from optimization to improve throughput (similar to GRPO-style buffering).

Changes:

  • Add buffered generation across gradient-accumulation windows, including multi-generation per prompt and vLLM dedup/remapping logic.
  • Introduce new config knobs (num_generations, generation_batch_size) with validation and updated revision handling (student_model_revision vs model_revision).
  • Update docs and the example training script to reflect the new configuration behavior.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| trl/experimental/gold/gold_trainer.py | Adds buffered dataloader strategy + vLLM multi-generation processing and related training-step changes. |
| trl/experimental/gold/gold_config.py | Adds num_generations / generation_batch_size and validates optimizer-window batch partitioning. |
| trl/experimental/gold/gold.py | Aligns model revision handling and teacher init kwargs; updates example wiring. |
| docs/source/gold_trainer.md | Documents new buffering knobs, revision behavior, and last-batch drop warning. |


new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1)
prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1)
new_attention_mask, new_labels = self._build_sequence_batch(
    new_input_ids, prompt_lengths, self.processing_class.pad_token_id

Copilot AI Mar 3, 2026


In _process_completions_to_buffer, completions are padded with pad_token_id (falling back to 0 when the tokenizer has no pad id), but _build_sequence_batch is called with self.processing_class.pad_token_id. If pad_token_id was forced to 0 because self.processing_class.pad_token_id is None, the attention_mask/labels won't mask those padded tokens, leading to incorrect loss and attention masking. Use the same pad_token_id value for both padding and _build_sequence_batch (or enforce that the tokenizer always has a pad token/id).

Suggested change
- new_input_ids, prompt_lengths, self.processing_class.pad_token_id
+ new_input_ids, prompt_lengths, pad_token_id
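The failure mode flagged here can be shown with a tiny plain-Python sketch (the real code operates on torch tensors; `build_attention_mask` is an illustrative stand-in for the masking step):

```python
def build_attention_mask(ids, pad_token_id):
    """Positions equal to pad_token_id are masked out (0); all others are 1."""
    return [0 if tok == pad_token_id else 1 for tok in ids]

seq = [5, 6, 7, 0, 0]                         # padded with the fallback id 0
mask_wrong = build_attention_mask(seq, None)  # tokenizer had no pad id
mask_right = build_attention_mask(seq, 0)     # same id used for padding
# mask_wrong leaves the two padded positions attended; mask_right masks them.
```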

Comment on lines +1560 to +1567
prompt_tokenized = self.processing_class(
    prompt_txts,
    return_tensors="pt",
    padding="longest",
    truncation=True if prompt_max_length else False,
    max_length=prompt_max_length,
    add_special_tokens=False,
).to(device)

Copilot AI Mar 3, 2026


prompt_tokenized = self.processing_class(... padding="longest" ...) relies on the tokenizer’s padding_side. If it’s the default "right", prompt_ids will include trailing pad tokens and new_input_ids = cat([prompt_ids, completion_ids_padded]) will insert a pad “gap” between prompt and completion, shifting position ids for the completion. Consider forcing left padding here (e.g., temporarily setting padding_side="left" for this call) or stripping padding before concatenation so completions start immediately after the last prompt token.

Suggested change
- prompt_tokenized = self.processing_class(
-     prompt_txts,
-     return_tensors="pt",
-     padding="longest",
-     truncation=True if prompt_max_length else False,
-     max_length=prompt_max_length,
-     add_special_tokens=False,
- ).to(device)
+ # Ensure padding does not create a gap between prompt and completion tokens.
+ original_padding_side = getattr(self.processing_class, "padding_side", None)
+ if original_padding_side is not None:
+     self.processing_class.padding_side = "left"
+ try:
+     prompt_tokenized = self.processing_class(
+         prompt_txts,
+         return_tensors="pt",
+         padding="longest",
+         truncation=True if prompt_max_length else False,
+         max_length=prompt_max_length,
+         add_special_tokens=False,
+     )
+ finally:
+     if original_padding_side is not None:
+         self.processing_class.padding_side = original_padding_side
+ prompt_tokenized = prompt_tokenized.to(device)

Comment on lines 2177 to 2192
if prompt_mask is not None:
    prompt_lengths = prompt_mask.sum(dim=1).to(torch.long)
else:
    if pad_token_id is not None:
        prompt_lengths = (inputs["prompts"] != pad_token_id).sum(dim=1).to(torch.long)
    else:
        prompt_lengths = torch.full(
            (batch_size,),
            inputs["prompts"].shape[1],
            dtype=torch.long,
            device=device,
        )

new_input_ids = generated_tokens
new_attention_mask = torch.ones_like(new_input_ids)
if pad_token_id is not None:
    new_attention_mask[new_input_ids == pad_token_id] = 0

new_labels = torch.full_like(new_input_ids, -100)
for idx in range(batch_size):
    length = int(prompt_lengths[idx].item())
    new_labels[idx, length:] = new_input_ids[idx, length:]

if pad_token_id is not None:
    new_labels[new_input_ids == pad_token_id] = -100
new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id)


Copilot AI Mar 3, 2026


generate_on_policy_outputs builds labels via _build_sequence_batch(new_input_ids, prompt_lengths, ...) where prompt_lengths is computed as prompt_attention_mask.sum(dim=1). Because DataCollatorForChatML left-pads prompts, the completion in generated_outputs.sequences typically starts after the full prompt tensor length (including left padding), not after the number of non-pad tokens. This can cause some prompt tokens to be treated as completion labels. Use the padded prompt length (e.g., inputs["prompts"].shape[1]) for masking the prompt portion of generated_tokens, while keeping the non-pad length only for decoding/logging.
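The left-padding pitfall described here can be made concrete with a plain-Python sketch: the completion starts at the padded prompt width, not at the count of non-pad tokens:

```python
PAD = 0
prompt_row = [PAD, PAD, 11, 12]  # left-padded prompt, padded width 4
completion = [21, 22]

non_pad_length = sum(tok != PAD for tok in prompt_row)  # 2 -- wrong offset
padded_length = len(prompt_row)                         # 4 -- correct offset

sequence = prompt_row + completion
# Labeling from non_pad_length would treat prompt tokens 11 and 12 as
# completion targets; labeling from padded_length starts exactly at 21.
```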

"""

import logging
import os

Copilot AI Mar 3, 2026


os is imported but not used in this script, which will fail linting / static checks in many setups. Remove the unused import, or use it where intended.

Suggested change
- import os

Comment on lines +1131 to +1174
def get_train_dataloader(self):
    """
    Override Trainer.get_train_dataloader to load one generation batch per optimizer window.

    The base dataloader yields local batches of size
    `per_device_train_batch_size * gradient_accumulation_steps`, then repeats each batch
    `gradient_accumulation_steps` times so Trainer can run accumulation mini-steps without re-sampling prompts.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    train_dataset = self.train_dataset
    data_collator = self.data_collator
    if is_datasets_available() and isinstance(train_dataset, Dataset):
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
    else:
        data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

    dataloader_params = {
        "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
    }

    if not isinstance(train_dataset, torch.utils.data.IterableDataset):
        dataloader_params["sampler"] = self._get_train_sampler()
        dataloader_params["drop_last"] = self.args.dataloader_drop_last
        dataloader_params["worker_init_fn"] = partial(
            seed_worker,
            num_workers=self.args.dataloader_num_workers,
            rank=self.args.process_index,
        )
        if self.args.dataloader_num_workers > 0:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

    base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
    return _RepeatEachBatchDataLoader(
        base_dataloader,
        repeat_count=self.args.gradient_accumulation_steps,
        expected_batch_size=self._train_batch_size * self.args.gradient_accumulation_steps,
    )


Copilot AI Mar 3, 2026


Buffering introduces substantial new behavior (custom train sampler/dataloader repetition, _fill_buffer on/off-policy slicing, and multi-generation vLLM remapping) but the existing tests/experimental/test_gold_trainer.py suite doesn’t cover these code paths. Adding unit tests for: (1) _RepeatEachBatchDataLoader dropping/len behavior, (2) _fill_buffer producing deterministic slices across grad-accum steps, and (3) num_generations>1 vLLM dedup/remap producing exactly one distinct completion per duplicate prompt would help prevent regressions.
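A test for point (1) could look like the sketch below. Since the real `_RepeatEachBatchDataLoader` is not shown here, a minimal re-implementation stands in for it, so this only illustrates the expected contract (repeat full batches, drop partial ones):

```python
class RepeatEachBatch:
    """Illustrative stand-in for _RepeatEachBatchDataLoader."""

    def __init__(self, batches, repeat_count, expected_batch_size):
        self.batches = batches
        self.repeat_count = repeat_count
        self.expected_batch_size = expected_batch_size

    def __iter__(self):
        for batch in self.batches:
            if len(batch) != self.expected_batch_size:
                continue  # drop partial final batches
            for _ in range(self.repeat_count):
                yield batch


def test_repeats_and_drops_partial_batches():
    loader = RepeatEachBatch([[1, 2], [3, 4], [5]], repeat_count=2, expected_batch_size=2)
    out = list(loader)
    # Full batches are repeated repeat_count times; the partial [5] is dropped.
    assert out == [[1, 2], [1, 2], [3, 4], [3, 4]]
```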

################
# Model & Tokenizer
################
if training_args.student_model_revision is None:
Collaborator

@edbeeching edbeeching Mar 4, 2026


Why do we need the student_model_revision parameter, can we not just use model_revision ?

Collaborator Author


True. I'll support just model_revision

return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length


class _RepeatEachBatchDataLoader:
Collaborator


Is this the same as the GRPO trainer one? can they be shared?

Collaborator Author


The main reason to use _RepeatEachBatchDataLoader is that it doesn't re-collate samples that will be effectively ignored during training, and it handles dropping partial batches.

Both RepeatSampler (the one used in GRPO) and _RepeatEachBatchDataLoader repeat the batch to get the right number of optimization steps from the transformers Trainer. These repeated batches are ignored because the samples are taken from the buffer instead of from the repeated batch.

Looking back, I think the efficiency gains and handling the partial batches are negligible compared to keeping the codebase tidy without extra classes.

Do you agree with dropping the extra class and using RepeatSampler instead?

Collaborator


If you can get comparable results / speed then I agree to keep the codebase leaner

@lewtun
Member

lewtun commented Mar 4, 2026

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7e9cb5eb56


Comment on lines +1600 to +1601
prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1)
new_attention_mask, new_labels = self._build_sequence_batch(


P1: Mask prompt span by padded width in vLLM buffer path

In _process_completions_to_buffer, prompt_lengths is computed from non-pad token counts and then used to build labels, which is incorrect when the tokenizer pads on the left: shorter prompts will have some prompt tokens treated as completion targets and contribute to loss. This is a regression from the previous vLLM path (which masked the full prompt width via prompt_ids.shape[1]) and can corrupt distillation whenever users run GOLD with padding_side='left' or tokenizers configured that way.



@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.


completion_ids_padded = torch.stack(padded_completion_ids_list)

new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1)
prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1)


Prompt length computation wrong for left-padded tokenizers

Medium Severity

In _process_completions_to_buffer, prompt_lengths is computed as (prompt_ids != pad_token_id).sum(dim=1), which counts non-pad tokens rather than reflecting the actual position where completions start in the concatenated tensor. Completions are appended after the full prompt_ids tensor (including padding), so they always start at position prompt_ids.shape[1]. With left-padded tokenizers (e.g. [PAD, PAD, tok1, tok2]), the non-pad count (2) is less than the padded length (4), causing _build_sequence_batch to label actual prompt tokens as completion tokens in the loss. The old vLLM path correctly used the full padded length.


    )
else:
    self._generate_non_vllm_for_slices(slices, on_policy_indices)
    return


Silent fallback for unknown vLLM mode removes error

Low Severity

When use_vllm=True but vllm_mode is neither "server" nor "colocate", the else branch silently falls back to _generate_non_vllm_for_slices instead of raising an error. The old code raised ValueError(f"Unknown vllm_mode: {self.vllm_mode}"). A user who sets use_vllm=True with a typo in vllm_mode would silently get slow non-vLLM generation with no indication anything is wrong, after _wake_vllm_if_needed() has already been called.
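The fail-loudly dispatch this comment argues for can be sketched as follows. The function name and return values are illustrative, not the actual trainer code:

```python
def dispatch_generation(vllm_mode: str) -> str:
    """Route generation to the requested vLLM mode, rejecting unknown values."""
    if vllm_mode == "server":
        return "vllm-server"
    elif vllm_mode == "colocate":
        return "vllm-colocate"
    else:
        # Fail loudly on typos instead of silently falling back to slow
        # non-vLLM generation after vLLM has already been woken up.
        raise ValueError(f"Unknown vllm_mode: {vllm_mode}")
```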

