[GKD] Buffer Implementation for Distillation Trainer #5137

cmpatino wants to merge 28 commits into huggingface:main
Conversation
Avoid crashing when using DeepSpeed ZeRO-3 and set up the correct values for `weight_hard_loss` and `weight_soft_loss`
KD Buffer Simplification
Add scripts to run GOLD
Pull request overview
Implements prompt-level rollout buffering and multi-generation support for GOLDTrainer, decoupling generation from optimization to improve throughput (similar to GRPO-style buffering).
Changes:
- Add buffered generation across gradient-accumulation windows, including multi-generation per prompt and vLLM dedup/remapping logic.
- Introduce new config knobs (`num_generations`, `generation_batch_size`) with validation and updated revision handling (`student_model_revision` vs `model_revision`).
- Update docs and the example training script to reflect the new configuration behavior.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `trl/experimental/gold/gold_trainer.py` | Adds buffered dataloader strategy + vLLM multi-generation processing and related training-step changes. |
| `trl/experimental/gold/gold_config.py` | Adds `num_generations` / `generation_batch_size` and validates optimizer-window batch partitioning. |
| `trl/experimental/gold/gold.py` | Aligns model revision handling and teacher init kwargs; updates example wiring. |
| `docs/source/gold_trainer.md` | Documents new buffering knobs, revision behavior, and last-batch drop warning. |
```python
new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1)
prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1)
new_attention_mask, new_labels = self._build_sequence_batch(
    new_input_ids, prompt_lengths, self.processing_class.pad_token_id
```
In `_process_completions_to_buffer`, completions are padded with `pad_token_id` (falling back to 0 when the tokenizer has no pad id), but `_build_sequence_batch` is called with `self.processing_class.pad_token_id`. If `pad_token_id` was forced to 0 because the tokenizer's `pad_token_id` is `None`, the attention mask and labels won't mask those padded tokens, leading to incorrect loss and attention masking. Use the same `pad_token_id` value for both padding and `_build_sequence_batch` (or enforce that the tokenizer always has a pad token/id).
```diff
-    new_input_ids, prompt_lengths, self.processing_class.pad_token_id
+    new_input_ids, prompt_lengths, pad_token_id
```
```python
prompt_tokenized = self.processing_class(
    prompt_txts,
    return_tensors="pt",
    padding="longest",
    truncation=True if prompt_max_length else False,
    max_length=prompt_max_length,
    add_special_tokens=False,
).to(device)
```
prompt_tokenized = self.processing_class(... padding="longest" ...) relies on the tokenizer’s padding_side. If it’s the default "right", prompt_ids will include trailing pad tokens and new_input_ids = cat([prompt_ids, completion_ids_padded]) will insert a pad “gap” between prompt and completion, shifting position ids for the completion. Consider forcing left padding here (e.g., temporarily setting padding_side="left" for this call) or stripping padding before concatenation so completions start immediately after the last prompt token.
```diff
-prompt_tokenized = self.processing_class(
-    prompt_txts,
-    return_tensors="pt",
-    padding="longest",
-    truncation=True if prompt_max_length else False,
-    max_length=prompt_max_length,
-    add_special_tokens=False,
-).to(device)
+# Ensure padding does not create a gap between prompt and completion tokens.
+original_padding_side = getattr(self.processing_class, "padding_side", None)
+if original_padding_side is not None:
+    self.processing_class.padding_side = "left"
+try:
+    prompt_tokenized = self.processing_class(
+        prompt_txts,
+        return_tensors="pt",
+        padding="longest",
+        truncation=True if prompt_max_length else False,
+        max_length=prompt_max_length,
+        add_special_tokens=False,
+    )
+finally:
+    if original_padding_side is not None:
+        self.processing_class.padding_side = original_padding_side
+prompt_tokenized = prompt_tokenized.to(device)
```
```python
if prompt_mask is not None:
    prompt_lengths = prompt_mask.sum(dim=1).to(torch.long)
else:
    if pad_token_id is not None:
        prompt_lengths = (inputs["prompts"] != pad_token_id).sum(dim=1).to(torch.long)
    else:
        prompt_lengths = torch.full(
            (batch_size,),
            inputs["prompts"].shape[1],
            dtype=torch.long,
            device=device,
        )

new_input_ids = generated_tokens
new_attention_mask = torch.ones_like(new_input_ids)
if pad_token_id is not None:
    new_attention_mask[new_input_ids == pad_token_id] = 0

new_labels = torch.full_like(new_input_ids, -100)
for idx in range(batch_size):
    length = int(prompt_lengths[idx].item())
    new_labels[idx, length:] = new_input_ids[idx, length:]

if pad_token_id is not None:
    new_labels[new_input_ids == pad_token_id] = -100
new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id)
```
generate_on_policy_outputs builds labels via _build_sequence_batch(new_input_ids, prompt_lengths, ...) where prompt_lengths is computed as prompt_attention_mask.sum(dim=1). Because DataCollatorForChatML left-pads prompts, the completion in generated_outputs.sequences typically starts after the full prompt tensor length (including left padding), not after the number of non-pad tokens. This can cause some prompt tokens to be treated as completion labels. Use the padded prompt length (e.g., inputs["prompts"].shape[1]) for masking the prompt portion of generated_tokens, while keeping the non-pad length only for decoding/logging.
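A minimal sketch of this failure mode, using a hypothetical left-padded batch and plain Python lists (`-100` is the usual label ignore index; values and helper are illustrative, not the trainer's code): masking by the non-pad count leaves trailing prompt tokens labeled as completion targets, while masking by the padded prompt width does not.

```python
PAD = 0
IGNORE = -100

# Hypothetical left-padded prompt batch: row 0 has 2 real tokens, row 1 has 4.
prompts = [[PAD, PAD, 5, 6], [1, 2, 3, 4]]
completions = [[7, 8], [9, 10]]
sequences = [p + c for p, c in zip(prompts, completions)]

def labels_with(prompt_len, row):
    # Mask the first `prompt_len` positions, keep the rest as targets.
    return [IGNORE] * prompt_len + row[prompt_len:]

# Buggy: mask by non-pad count (2 for row 0) instead of the padded width (4).
nonpad = [sum(t != PAD for t in p) for p in prompts]
buggy = [labels_with(n, s) for n, s in zip(nonpad, sequences)]
# Row 0 wrongly exposes prompt tokens 5 and 6 as completion targets:
assert buggy[0] == [IGNORE, IGNORE, 5, 6, 7, 8]

# Fix: mask the full padded prompt width for every row.
width = len(prompts[0])
fixed = [labels_with(width, s) for s in sequences]
assert fixed[0] == [IGNORE, IGNORE, IGNORE, IGNORE, 7, 8]
```

The non-pad length remains useful for decoding and logging; only the label/attention masking needs the padded width.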
| """ | ||
|
|
||
| import logging | ||
| import os |
os is imported but not used in this script, which will fail linting / static checks in many setups. Remove the unused import, or use it where intended.
```diff
-import os
```
```python
def get_train_dataloader(self):
    """
    Override Trainer.get_train_dataloader to load one generation batch per optimizer window.

    The base dataloader yields local batches of size
    `per_device_train_batch_size * gradient_accumulation_steps`, then repeats each batch
    `gradient_accumulation_steps` times so Trainer can run accumulation mini-steps without re-sampling prompts.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    train_dataset = self.train_dataset
    data_collator = self.data_collator
    if is_datasets_available() and isinstance(train_dataset, Dataset):
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
    else:
        data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

    dataloader_params = {
        "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
    }

    if not isinstance(train_dataset, torch.utils.data.IterableDataset):
        dataloader_params["sampler"] = self._get_train_sampler()
        dataloader_params["drop_last"] = self.args.dataloader_drop_last
        dataloader_params["worker_init_fn"] = partial(
            seed_worker,
            num_workers=self.args.dataloader_num_workers,
            rank=self.args.process_index,
        )
        if self.args.dataloader_num_workers > 0:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

    base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
    return _RepeatEachBatchDataLoader(
        base_dataloader,
        repeat_count=self.args.gradient_accumulation_steps,
        expected_batch_size=self._train_batch_size * self.args.gradient_accumulation_steps,
    )
```
Buffering introduces substantial new behavior (custom train sampler/dataloader repetition, _fill_buffer on/off-policy slicing, and multi-generation vLLM remapping) but the existing tests/experimental/test_gold_trainer.py suite doesn’t cover these code paths. Adding unit tests for: (1) _RepeatEachBatchDataLoader dropping/len behavior, (2) _fill_buffer producing deterministic slices across grad-accum steps, and (3) num_generations>1 vLLM dedup/remap producing exactly one distinct completion per duplicate prompt would help prevent regressions.
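As a sketch of what a test for point (1) could pin down, here is a simplified stand-in for the repeat/drop behavior (the real `_RepeatEachBatchDataLoader` wraps an accelerate-prepared `DataLoader`; the `RepeatEach` class, its parameters, and the toy batches below are illustrative only):

```python
class RepeatEach:
    """Toy stand-in: yield each full-size batch `repeat_count` times, drop partial batches."""

    def __init__(self, batches, repeat_count, expected_batch_size):
        self.batches = batches
        self.repeat_count = repeat_count
        self.expected_batch_size = expected_batch_size

    def __iter__(self):
        for batch in self.batches:
            if len(batch) != self.expected_batch_size:
                continue  # drop the trailing partial batch
            for _ in range(self.repeat_count):
                yield batch

    def __len__(self):
        full = sum(1 for b in self.batches if len(b) == self.expected_batch_size)
        return full * self.repeat_count

# 10 samples with an optimizer-window batch of 4 -> 2 full batches + 1 partial (dropped).
batches = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
loader = RepeatEach(batches, repeat_count=3, expected_batch_size=4)
out = list(loader)
assert len(loader) == 6            # 2 full batches * 3 repeats
assert out[0] == out[1] == out[2]  # each batch repeated for accumulation mini-steps
assert [8, 9] not in out           # partial batch dropped
```

A real test would exercise the actual class with a tiny dataset and assert the same repeat count, `len`, and last-batch-drop behavior.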
```python
################
# Model & Tokenizer
################
if training_args.student_model_revision is None:
```
Why do we need the `student_model_revision` parameter? Can we not just use `model_revision`?

True. I'll support just `model_revision`.
```python
    return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length


class _RepeatEachBatchDataLoader:
```
Is this the same as the GRPO trainer one? Can they be shared?

The main reason to use `_RepeatEachBatchDataLoader` is that it doesn't re-collate samples that will be effectively ignored during training, and it handles dropping partial batches.

Both `RepeatSampler` (the one used in GRPO) and `_RepeatEachBatchDataLoader` repeat the batch to get the right number of optimization steps from the transformers `Trainer`. These repeated batches are ignored because the samples are taken from the buffer instead of from the repeated batch.

Looking back, I think the efficiency gains and the partial-batch handling are negligible compared to keeping the codebase tidy without extra classes. Do you agree with dropping the extra class and using `RepeatSampler` instead?

If you can get comparable results / speed, then I agree to keep the codebase leaner.
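For context, the sampler-level alternative repeats index chunks rather than collated batches. A rough sketch of that idea (simplified signature and logic, not TRL's actual `RepeatSampler` API):

```python
import random

class SimpleRepeatSampler:
    """Yield each chunk of `batch_size` shuffled indices `repeat_count` times (simplified)."""

    def __init__(self, data_len, batch_size, repeat_count, seed=0):
        self.indices = list(range(data_len))
        random.Random(seed).shuffle(self.indices)
        self.batch_size = batch_size
        self.repeat_count = repeat_count

    def __iter__(self):
        # Full chunks only; a trailing partial chunk is skipped.
        for start in range(0, len(self.indices) - self.batch_size + 1, self.batch_size):
            chunk = self.indices[start:start + self.batch_size]
            for _ in range(self.repeat_count):
                yield from chunk

sampler = SimpleRepeatSampler(data_len=8, batch_size=4, repeat_count=2)
order = list(sampler)
assert len(order) == 16        # 2 chunks * 4 indices * 2 repeats
assert order[:4] == order[4:8] # each optimizer-window chunk is repeated back-to-back
```

The trade-off raised above is visible here: the repeated chunks are re-collated by the dataloader even though training reads from the buffer, which is the inefficiency `_RepeatEachBatchDataLoader` avoided.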
@codex review |
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7e9cb5eb56
```python
prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1)
new_attention_mask, new_labels = self._build_sequence_batch(
```
Mask prompt span by padded width in vLLM buffer path
In _process_completions_to_buffer, prompt_lengths is computed from non-pad token counts and then used to build labels, which is incorrect when the tokenizer pads on the left: shorter prompts will have some prompt tokens treated as completion targets and contribute to loss. This is a regression from the previous vLLM path (which masked the full prompt width via prompt_ids.shape[1]) and can corrupt distillation whenever users run GOLD with padding_side='left' or tokenizers configured that way.
Cursor Bugbot has reviewed your changes and found 2 potential issues.
```python
completion_ids_padded = torch.stack(padded_completion_ids_list)

new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1)
prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1)
```
Prompt length computation wrong for left-padded tokenizers
Medium Severity
In _process_completions_to_buffer, prompt_lengths is computed as (prompt_ids != pad_token_id).sum(dim=1), which counts non-pad tokens rather than reflecting the actual position where completions start in the concatenated tensor. Completions are appended after the full prompt_ids tensor (including padding), so they always start at position prompt_ids.shape[1]. With left-padded tokenizers (e.g. [PAD, PAD, tok1, tok2]), the non-pad count (2) is less than the padded length (4), causing _build_sequence_batch to label actual prompt tokens as completion tokens in the loss. The old vLLM path correctly used the full padded length.
```python
    )
else:
    self._generate_non_vllm_for_slices(slices, on_policy_indices)
return
```
Silent fallback for unknown vLLM mode removes error
Low Severity
When use_vllm=True but vllm_mode is neither "server" nor "colocate", the else branch silently falls back to _generate_non_vllm_for_slices instead of raising an error. The old code raised ValueError(f"Unknown vllm_mode: {self.vllm_mode}"). A user who sets use_vllm=True with a typo in vllm_mode would silently get slow non-vLLM generation with no indication anything is wrong, after _wake_vllm_if_needed() has already been called.
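One way to restore the fail-fast behavior is to dispatch explicitly and raise on anything unrecognized. A simplified sketch (the `dispatch_generation` function and its return values are illustrative, not the trainer's actual interface):

```python
def dispatch_generation(use_vllm, vllm_mode):
    """Simplified dispatch: fail fast on an unrecognized vllm_mode instead of silently degrading."""
    if not use_vllm:
        return "non_vllm"
    if vllm_mode == "server":
        return "vllm_server"
    if vllm_mode == "colocate":
        return "vllm_colocate"
    raise ValueError(f"Unknown vllm_mode: {vllm_mode!r} (expected 'server' or 'colocate')")

assert dispatch_generation(False, None) == "non_vllm"
assert dispatch_generation(True, "server") == "vllm_server"
try:
    dispatch_generation(True, "sever")  # typo should raise, not fall back silently
except ValueError as e:
    assert "Unknown vllm_mode" in str(e)
else:
    raise AssertionError("expected ValueError")
```

Validating `vllm_mode` once in the config's `__post_init__` would catch the typo even earlier, before `_wake_vllm_if_needed()` runs.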


Implement generation buffering and multi-generation support for `GOLDTrainer`
Add a prompt-level generation buffer that decouples generation from the
optimization steps. We adopt a buffer similar to GRPO to generate all rollouts for all mini-batches within an optimization step, leveraging parallel inference engines. This means each worker handles a buffer of
`per_device_train_batch_size * gradient_accumulation_steps` samples.

Buffer Details
We allow multiple rollouts per prompt, following Thinking Machines' Tinker example. The number of rollouts per prompt is determined by the `num_generations` parameter. To keep the effective batch size constant, we introduce the `generation_batch_size` parameter, which controls how many unique prompts we pass to the inference engine. We enforce

`generation_batch_size = per_device_train_batch_size * gradient_accumulation_steps // num_generations`

to ensure the effective batch size is invariant across setups.

Benchmarks
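The constraint above can be checked with simple arithmetic (the `check_partition` helper and the example values are illustrative, not the config's actual validation code):

```python
def check_partition(per_device_bs, grad_accum, num_generations):
    """Validate that num_generations exactly partitions the local optimizer-window batch."""
    window = per_device_bs * grad_accum
    if window % num_generations != 0:
        raise ValueError(
            f"per_device_train_batch_size * gradient_accumulation_steps ({window}) "
            f"must be divisible by num_generations ({num_generations})"
        )
    return window // num_generations  # generation_batch_size: unique prompts per window

# 4 * 8 = 32 samples per optimizer window with 4 rollouts per prompt -> 8 unique prompts.
assert check_partition(per_device_bs=4, grad_accum=8, num_generations=4) == 8
# Effective batch size stays invariant: unique prompts * rollouts == window size.
assert 8 * 4 == 4 * 8
```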
We can replicate Thinking Machine’s results using both non-Liger and Liger losses, achieving a 3x speedup on a setup with 8 training nodes in colocate mode.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
Note
Medium Risk
Changes core training-loop mechanics (dataloader/sampling, buffering, and generation paths), which can affect correctness and throughput across distributed/vLLM/ZeRO setups. Config validation reduces misconfiguration risk but new batch-partitioning constraints and buffering edge cases could surface in varied training pipelines.
Overview
`GOLDTrainer` now buffers rollouts across gradient-accumulation windows: the train dataloader samples one optimizer-window batch via `RepeatSampler`, `_prepare_inputs` slices it per accumulation step, and on-policy slices trigger a single batched generation pass (vLLM server/colocate or `model.generate`) whose results are written back into the buffered inputs for training.

The config surface is updated with `num_generations` and `generation_batch_size` (with validation that they exactly partition the local optimizer-step batch), plus improved handling of student model revisions (defaulting `student_model_revision` from `model_revision` and erroring on conflicts); docs and examples are updated accordingly. Misc fixes include vLLM prompt de-duplication for multi-gen, standardized label/attention building, Liger JSD tweaks (weights + reshaping + ZeRO-3 gather context), and preserving `messages` when using Liger.

Written by Cursor Bugbot for commit f89e77f.