Skip to content

Commit f0c2490

Browse files
authored
Merge branch 'main' into VESPO
2 parents 3c3d318 + 7544c3a commit f0c2490

20 files changed

+433
-73
lines changed

.github/workflows/tests-experimental.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ on:
1010
env:
1111
TQDM_DISABLE: 1
1212
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
13+
PYTORCH_ALLOC_CONF: "expandable_segments:True"
1314
TRL_EXPERIMENTAL_SILENCE: 1
1415

1516
jobs:

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ env:
2222
TQDM_DISABLE: 1
2323
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
2424
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
25+
PYTORCH_ALLOC_CONF: "expandable_segments:True"
2526

2627
jobs:
2728
check_code_quality:

.github/workflows/tests_transformers_branch.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ env:
1212
TQDM_DISABLE: 1
1313
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
1414
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
15+
PYTORCH_ALLOC_CONF: "expandable_segments:True"
1516

1617
jobs:
1718
tests_transformers_branch:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ dev = [
109109
# kernels
110110
"kernels",
111111
# liger
112-
#"liger-kernel>=0.7.0",
112+
"liger-kernel>=0.7.0",
113113
# peft
114114
"peft>=0.8.0",
115115
# quality

tests/experimental/test_kto_trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import multiprocess
1516
import pytest
1617
import torch
1718
from datasets import load_dataset
@@ -98,6 +99,11 @@ def test_kto_trainer_with_ref_model_is_model(self):
9899
)
99100

100101
def test_tokenize_and_process_tokens(self):
102+
# Pytest/CI often starts background threads before tests run. Under Python 3.12+,
103+
# using "fork" in a multi-threaded process emits a DeprecationWarning and may deadlock.
104+
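# Force "spawn" to keep this multiprocessing test safe while still exercising `num_proc=2`. Note that `force=True` overrides the process-wide start method, so it stays in effect for tests that run afterwards.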
105+
multiprocess.set_start_method("spawn", force=True)
106+
101107
training_args = KTOConfig(
102108
output_dir=self.tmp_dir,
103109
per_device_train_batch_size=2,

tests/experimental/test_nash_md_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from trl.experimental.nash_md import NashMDConfig, NashMDTrainer
2222
from trl.experimental.nash_md.nash_md_trainer import GeometricMixtureWrapper
23-
from trl.models.utils import create_reference_model
23+
from trl.experimental.utils import create_reference_model
2424

2525
from ..testing_utils import TrlTestCase, require_llm_blender, require_peft
2626
from .testing_utils import RandomPairwiseJudge

tests/test_cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def test_vllm_serve_config_file(self):
130130
with open(config_path, "w") as f:
131131
yaml.dump({"model": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"}, f)
132132

133-
with patch("trl.cli.commands.vllm_serve.vllm_serve_main") as mock_serve:
133+
# Patch the actual function that `VllmServeCommand.run` imports as `vllm_serve_main`
134+
with patch("trl.scripts.vllm_serve.main") as mock_serve:
134135
with patch("sys.argv", ["trl", "vllm-serve", "--config", config_path]):
135136
main()
136137

tests/test_grpo_trainer.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,36 @@ def test_training_beta_non_zero(self):
861861
new_param = trainer.model.get_parameter(n)
862862
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
863863

864+
def test_training_with_pad_to_multiple_of(self):
865+
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
866+
867+
training_args = GRPOConfig(
868+
output_dir=self.tmp_dir,
869+
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
870+
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
871+
num_generations=3, # reduce the number of generations to reduce memory usage
872+
max_completion_length=8, # reduce the completion length to reduce memory usage
873+
pad_to_multiple_of=8,
874+
report_to="none",
875+
)
876+
trainer = GRPOTrainer(
877+
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
878+
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
879+
args=training_args,
880+
train_dataset=dataset,
881+
)
882+
883+
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
884+
885+
trainer.train()
886+
887+
assert trainer.state.log_history[-1]["train_loss"] is not None
888+
889+
# Check that the params have changed
890+
for n, param in previous_trainable_params.items():
891+
new_param = trainer.model.get_parameter(n)
892+
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
893+
864894
def test_get_off_policy_mask(self):
865895
"""
866896
Test the logic of off-policy masking:
@@ -1771,6 +1801,43 @@ def reward_func(completions, **kwargs):
17711801
new_param = trainer.model.get_parameter(n)
17721802
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
17731803

1804+
@require_vision
1805+
def test_training_vlm_with_pad_to_multiple_of(self):
1806+
# Models like Gemma3 take extra forward keyword arguments (e.g. `token_type_ids`) that must also be padded when
1807+
# `pad_to_multiple_of` is set, so this test checks that the trainer correctly pads all the necessary inputs.
1808+
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
1809+
1810+
def reward_func(completions, **kwargs):
1811+
"""Reward function that rewards longer completions."""
1812+
return [float(len(completion[0]["content"])) for completion in completions]
1813+
1814+
training_args = GRPOConfig(
1815+
output_dir=self.tmp_dir,
1816+
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
1817+
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
1818+
num_generations=3, # reduce the number of generations to reduce memory usage
1819+
max_completion_length=8, # reduce the completion length to reduce memory usage
1820+
pad_to_multiple_of=7,
1821+
report_to="none",
1822+
)
1823+
trainer = GRPOTrainer(
1824+
model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
1825+
reward_funcs=reward_func,
1826+
args=training_args,
1827+
train_dataset=dataset,
1828+
)
1829+
1830+
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
1831+
1832+
trainer.train()
1833+
1834+
assert trainer.state.log_history[-1]["train_loss"] is not None
1835+
1836+
# Check that the params have changed
1837+
for n, param in previous_trainable_params.items():
1838+
new_param = trainer.model.get_parameter(n)
1839+
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
1840+
17741841
@pytest.mark.parametrize(
17751842
"model_id",
17761843
[
@@ -2554,6 +2621,47 @@ def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
25542621

25552622
release_memory(model, trainer)
25562623

2624+
@require_liger_kernel
2625+
def test_liger_grpo_kernel_importance_sampling(self):
2626+
model_name = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
2627+
2628+
training_args = GRPOConfig(
2629+
output_dir=self.tmp_dir,
2630+
per_device_train_batch_size=3,
2631+
num_generations=3,
2632+
use_liger_kernel=True,
2633+
max_completion_length=self.max_length,
2634+
importance_sampling_level="sequence",
2635+
report_to="none",
2636+
logging_strategy="no",
2637+
)
2638+
2639+
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32")
2640+
tokenizer = AutoTokenizer.from_pretrained(model_name)
2641+
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
2642+
2643+
trainer = GRPOTrainer(
2644+
model=model,
2645+
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
2646+
args=training_args,
2647+
train_dataset=self.train_dataset,
2648+
eval_dataset=self.eval_dataset,
2649+
processing_class=tokenizer,
2650+
)
2651+
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
2652+
2653+
assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss)
2654+
2655+
previous_trainable_params = {n: param.clone() for n, param in model.named_parameters()}
2656+
2657+
trainer.train()
2658+
2659+
for n, param in previous_trainable_params.items():
2660+
new_param = model.get_parameter(n)
2661+
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
2662+
2663+
release_memory(model, trainer)
2664+
25572665
@pytest.mark.parametrize(
25582666
"model_name",
25592667
[

tests/test_rloo_trainer.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,36 @@ def test_training_beta_zero(self):
678678
new_param = trainer.model.get_parameter(n)
679679
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
680680

681+
def test_training_with_pad_to_multiple_of(self):
682+
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
683+
684+
training_args = RLOOConfig(
685+
output_dir=self.tmp_dir,
686+
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
687+
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
688+
num_generations=3, # reduce the number of generations to reduce memory usage
689+
max_completion_length=8, # reduce the completion length to reduce memory usage
690+
pad_to_multiple_of=8,
691+
report_to="none",
692+
)
693+
trainer = RLOOTrainer(
694+
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
695+
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
696+
args=training_args,
697+
train_dataset=dataset,
698+
)
699+
700+
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
701+
702+
trainer.train()
703+
704+
assert trainer.state.log_history[-1]["train_loss"] is not None
705+
706+
# Check that the params have changed
707+
for n, param in previous_trainable_params.items():
708+
new_param = trainer.model.get_parameter(n)
709+
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
710+
681711
@require_peft
682712
@require_vllm
683713
@pytest.mark.skip(reason="We should add a mock for the vLLM server.")
@@ -1242,6 +1272,43 @@ def reward_func(completions, **kwargs):
12421272
new_param = trainer.model.get_parameter(n)
12431273
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
12441274

1275+
@require_vision
1276+
def test_training_vlm_with_pad_to_multiple_of(self):
1277+
# Models like Gemma3 take extra forward keyword arguments (e.g. `token_type_ids`) that must also be padded when
1278+
# `pad_to_multiple_of` is set, so this test checks that the trainer correctly pads all the necessary inputs.
1279+
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
1280+
1281+
def reward_func(completions, **kwargs):
1282+
"""Reward function that rewards longer completions."""
1283+
return [float(len(completion[0]["content"])) for completion in completions]
1284+
1285+
training_args = RLOOConfig(
1286+
output_dir=self.tmp_dir,
1287+
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
1288+
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
1289+
num_generations=3, # reduce the number of generations to reduce memory usage
1290+
max_completion_length=8, # reduce the completion length to reduce memory usage
1291+
pad_to_multiple_of=7,
1292+
report_to="none",
1293+
)
1294+
trainer = RLOOTrainer(
1295+
model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
1296+
reward_funcs=reward_func,
1297+
args=training_args,
1298+
train_dataset=dataset,
1299+
)
1300+
1301+
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
1302+
1303+
trainer.train()
1304+
1305+
assert trainer.state.log_history[-1]["train_loss"] is not None
1306+
1307+
# Check that the params have changed
1308+
for n, param in previous_trainable_params.items():
1309+
new_param = trainer.model.get_parameter(n)
1310+
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
1311+
12451312
@pytest.mark.parametrize(
12461313
"model_id",
12471314
[

tests/test_sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_ba
812812
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
813813
}
814814

815-
with patch("trl.trainer.sft_trainer.BaseTrainer.compute_loss", side_effect=mock_super_compute_loss):
815+
with patch("transformers.Trainer.compute_loss", side_effect=mock_super_compute_loss):
816816
trainer.compute_loss(trainer.model, inputs)
817817

818818
assert captured["skip_logits"] is True
@@ -846,7 +846,7 @@ def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_ba
846846
dummy_outputs = (dummy_loss, torch.randn(1, 5, trainer.model.config.vocab_size))
847847
return (dummy_loss, dummy_outputs)
848848

849-
with patch("trl.trainer.sft_trainer.BaseTrainer.compute_loss", side_effect=mock_super_compute_loss):
849+
with patch("transformers.Trainer.compute_loss", side_effect=mock_super_compute_loss):
850850
trainer.predict(trainer.train_dataset)
851851

852852
assert captured["skip_logits"] is False

0 commit comments

Comments
 (0)