Skip to content
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
719c644
Implement buffer for GOLDTrainer
cmpatino Feb 18, 2026
904378b
Clean up code from KD buffer
cmpatino Feb 18, 2026
6a2ece5
Test scripts for trial run
cmpatino Feb 18, 2026
ee07aec
Apply fixes to the Liger loss setting
cmpatino Feb 23, 2026
a3fd2af
Remove test scripts
cmpatino Feb 25, 2026
b0669d9
Handle config parameters better in gold script
cmpatino Feb 25, 2026
b0c4f3e
Upload provisional SLURM script for GOLD
cmpatino Feb 25, 2026
602e564
Refine logic and comments
cmpatino Feb 26, 2026
c4f9a64
Improve clarity of buffer implementation
cmpatino Feb 28, 2026
111b85e
Add validation for num_generations
cmpatino Mar 2, 2026
022af62
Add clarifying comment to num_generations
cmpatino Mar 2, 2026
33e0a82
Patch issue with ZeRO-3
cmpatino Mar 2, 2026
dbb6e70
Refactor context for ZeRO-3 + Liger
cmpatino Mar 2, 2026
9da54b3
Simplify comments and code logic
cmpatino Mar 2, 2026
1cec9ea
Merge pull request #1 from cmpatino/kd-buffer-fix
cmpatino Mar 2, 2026
4435409
Add scripts to run GOLD
cmpatino Mar 2, 2026
ce41aba
Merge pull request #2 from cmpatino/kd-buffer-fix
cmpatino Mar 2, 2026
c0a857f
Merge branch 'kd-buffering' of github.com:cmpatino/trl into kd-buffering
cmpatino Mar 2, 2026
fa62472
Merge branch 'main' into kd-buffering
cmpatino Mar 2, 2026
31161a0
Refactor to simplify logic
cmpatino Mar 2, 2026
da7ef50
Handle student versioning params
cmpatino Mar 3, 2026
e24e681
Add warning when dropping incomplete batches
cmpatino Mar 3, 2026
8d31b7a
Add clarifying note in docs
cmpatino Mar 3, 2026
1ef205b
Remove SLURM script used for testing
cmpatino Mar 3, 2026
506afc1
Remove reference to wandb
cmpatino Mar 3, 2026
7e9cb5e
Merge branch 'main' into kd-buffering
lewtun Mar 4, 2026
98ec20c
Remove `_RepeatEachBatchDataLoader` to simplify codebase
cmpatino Mar 5, 2026
f89e77f
Merge branch 'kd-buffering' of github.com:cmpatino/trl into kd-buffering
cmpatino Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion docs/source/gold_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ messages). Important configuration flags on [`GOLDConfig`] include:
matched/unmatched loss.
* `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy
sampling ratio.
* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows.
`generation_batch_size` is the number of unique prompts per worker per optimizer step.
* `student_model_revision` and `model_revision` – if `student_model_revision` is unset, GOLD uses `model_revision`.
If both are set and differ, GOLD raises an error to avoid loading different revisions for training vs generation.

A minimal end-to-end example:

Expand Down Expand Up @@ -79,7 +83,7 @@ train_dataset = load_dataset(
training_args = GOLDConfig(
output_dir="gold-model",
per_device_train_batch_size=1,
teacher_model=teacher_name,
teacher_model_name_or_path=teacher_name,
teacher_tokenizer_name_or_path=teacher_name,
use_uld_loss=True,
uld_use_hybrid_loss=True,
Expand All @@ -95,6 +99,11 @@ trainer = GOLDTrainer(
trainer.train()
```

> [!NOTE]
> GOLD buffers one full optimizer-window generation batch (`per_device_train_batch_size * gradient_accumulation_steps`)
> and reuses it across accumulation steps. If the final batch is undersized, GOLD warns and drops that last batch
> (`Dropping last batch due to unexpected batch size`). Set `dataloader_drop_last=True` to avoid this warning.

### Expected dataset type

GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.:
Expand Down
19 changes: 16 additions & 3 deletions trl/experimental/gold/gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"""

import logging
import os
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

os is imported but not used in this script, which will fail linting / static checks in many setups. Remove the unused import, or use it where intended.

Suggested change
import os

Copilot uses AI. Check for mistakes.

from datasets import load_dataset
from transformers import AutoTokenizer, GenerationConfig
Expand Down Expand Up @@ -78,6 +79,19 @@
################
# Model & Tokenizer
################
if training_args.student_model_revision is None:
Copy link
Collaborator

@edbeeching edbeeching Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the student_model_revision parameter, can we not just use model_revision ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. I'll support just model_revision

training_args.student_model_revision = model_args.model_revision
elif (
model_args.model_revision is not None
and training_args.student_model_revision != model_args.model_revision
):
raise ValueError(
"Conflicting revisions for student model: "
f"student_model_revision={training_args.student_model_revision!r} and "
f"model_revision={model_args.model_revision!r}. "
"Set only one revision, or set both to the same value."
)

quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=training_args.student_model_revision,
Expand All @@ -93,21 +107,21 @@
if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss:
training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path
teacher_model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.dtype,
use_cache=True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
if training_args.teacher_model_init_kwargs is not None:
teacher_model_kwargs.update(training_args.teacher_model_init_kwargs)
training_args.teacher_model_init_kwargs = teacher_model_kwargs

tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
padding_side="left",
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
Expand All @@ -120,7 +134,6 @@
################
# Training
################
# Handle eval dataset - check if test split exists, fallback to validation or None
eval_dataset = None
if training_args.eval_strategy != "no":
if script_args.dataset_test_split in dataset:
Expand Down
55 changes: 42 additions & 13 deletions trl/experimental/gold/gold_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from typing import Any

Expand Down Expand Up @@ -53,6 +54,11 @@ class GOLDConfig(SFTConfig):
seq_kd (`bool`, *optional*, defaults to `False`):
    Whether to perform sequence-level knowledge distillation (can be viewed as supervised fine-tuning on
    teacher-generated outputs).
num_generations (`int`, *optional*, defaults to `1`):
Number of generations per prompt. Each prompt is repeated this many times in the generation batch.
generation_batch_size (`int` or `None`, *optional*, defaults to `None`):
Number of unique prompts per worker per optimizer step. If `None`, it is computed from
`(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`.
use_uld_loss (`bool`, *optional*, defaults to `False`):
Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence
loss.
Expand Down Expand Up @@ -140,10 +146,10 @@ class GOLDConfig(SFTConfig):
default=128,
metadata={"help": "Maximum number of tokens to generate per completion."},
)
student_model_revision: str = field(
default="main",
student_model_revision: str | None = field(
default=None,
metadata={
"help": "Revision of the student model to use. If not specified, the default revision of the model will be used."
"help": "Revision of the student model to use. If not specified, `model_revision` is used."
},
)
teacher_model_name_or_path: str | None = field(
Expand Down Expand Up @@ -178,10 +184,17 @@ class GOLDConfig(SFTConfig):
"FT on teacher-generated output)."
},
)
steps_per_generation: int | None = field(
num_generations: int = field(
default=1,
metadata={
"help": "Number of generations per prompt. Increasing this will decrease the number of unique prompts per optimization step."
},
)
generation_batch_size: int | None = field(
default=None,
metadata={
"help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps."
"help": "Number of unique prompts per worker per optimizer step. "
"If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations."
},
)

Expand Down Expand Up @@ -367,12 +380,6 @@ class GOLDConfig(SFTConfig):
num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."})
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
trl_project: str = field(
default="smollm3",
metadata={
"help": "The TRL project to use for evaluation. This is used to determine the path to the evaluation script."
},
)

def __post_init__(self):
super().__post_init__()
Expand All @@ -389,8 +396,30 @@ def __post_init__(self):
f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length."
)

if self.steps_per_generation is None:
self.steps_per_generation = self.gradient_accumulation_steps
if self.num_generations < 1:
raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.")
local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps
if self.generation_batch_size is None:
self.generation_batch_size = local_sequence_batch_size // self.num_generations
if self.generation_batch_size < 1:
raise ValueError(
"generation_batch_size must be at least 1. "
f"Got generation_batch_size={self.generation_batch_size}."
)
if self.generation_batch_size * self.num_generations != local_sequence_batch_size:
raise ValueError(
"generation_batch_size and num_generations must exactly partition the local optimizer-step batch. "
"Expected generation_batch_size * num_generations == per_device_train_batch_size * "
f"gradient_accumulation_steps, got {self.generation_batch_size} * {self.num_generations} != "
f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}."
)
if self.num_generations > 1 and self.lmbda < 1.0:
warnings.warn(
f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches include "
f"{self.num_generations} copies of each sample; consider lmbda=1.0 when num_generations > 1.",
UserWarning,
stacklevel=2,
)

# Validate ULD parameters
if self.use_uld_loss:
Expand Down
Loading
Loading