Skip to content

Commit 86e89a2

Browse files
committed
Resolve conflict
2 parents e08d8fe + 11be774 commit 86e89a2

23 files changed

+436
-289
lines changed

docs/source/vllm_integration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.
44

55
> [!WARNING]
6-
> TRL currently only supports vLLM versions from `0.10.2` to `0.14.1`. Please ensure you have a version in this range installed to avoid compatibility issues.
6+
> TRL currently only supports vLLM versions from `0.10.2` to `0.17.0`. Please ensure you have a version in this range installed to avoid compatibility issues.
77
88
> [!TIP]
99
> The following trainers currently support generation with vLLM:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ test = [
8383
"pytest"
8484
]
8585
vllm = [
86-
"vllm>=0.10.2,<=0.14.1",
86+
"vllm>=0.10.2,<=0.17.0",
8787
"fastapi",
8888
"pydantic",
8989
"requests",

tests/test_grpo_trainer.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,44 @@ def test_compute_entropy_all_masked(self):
162162
class TestGRPORolloutDispatch:
163163
def _make_trainer(self):
164164
trainer = object.__new__(GRPOTrainer)
165-
trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
165+
trainer.accelerator = SimpleNamespace(
166+
device=torch.device("cpu"),
167+
is_main_process=True,
168+
gather=lambda t: t,
169+
)
166170
trainer.args = SimpleNamespace(report_to=[])
167171
trainer.model = SimpleNamespace(training=True)
168-
trainer.state = SimpleNamespace(global_step=2)
172+
trainer.state = SimpleNamespace(global_step=2, num_input_tokens_seen=0)
169173
trainer._last_loaded_step = 1
170174
trainer.use_vllm = False
171175
trainer.use_transformers_paged = False
172176
trainer.vllm_generation = SimpleNamespace(sync_weights=MagicMock())
177+
trainer.processing_class = SimpleNamespace(
178+
batch_decode=MagicMock(return_value=["decoded"]),
179+
)
180+
trainer.tools = None
181+
trainer.eos_token_id = 2
182+
trainer.pad_token_id = 0
183+
trainer._metrics = {
184+
"train": {
185+
"num_tokens": [],
186+
**{
187+
k: []
188+
for k in [
189+
"completions/mean_length",
190+
"completions/min_length",
191+
"completions/max_length",
192+
"completions/clipped_ratio",
193+
"completions/mean_terminated_length",
194+
"completions/min_terminated_length",
195+
"completions/max_terminated_length",
196+
]
197+
},
198+
}
199+
}
173200
return trainer
174201

175-
def test_generate_single_turn_prefers_rollout_func(self):
202+
def test_generate_prefers_rollout_func(self):
176203
trainer = self._make_trainer()
177204
trainer.rollout_func = MagicMock(
178205
return_value={
@@ -183,33 +210,32 @@ def test_generate_single_turn_prefers_rollout_func(self):
183210
}
184211
)
185212

186-
prompt_ids, completion_ids, logprobs, extra_fields = trainer._generate_single_turn(["prompt"])
213+
result = trainer._generate(["prompt"])
187214

188-
assert prompt_ids == [[1]]
189-
assert completion_ids == [[2]]
190-
assert logprobs == [[-0.1]]
191-
assert extra_fields == {"env_mask": [[1]]}
215+
assert result[0] == [[1]] # prompt_ids
216+
assert result[1] == [[2]] # completion_ids
217+
assert result[2] == [[1]] # tool_mask (from env_mask)
192218
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)
193219

194-
def test_generate_single_turn_rollout_func_syncs_vllm_weights_when_needed(self):
220+
def test_generate_rollout_func_syncs_vllm_weights_when_needed(self):
195221
trainer = self._make_trainer()
196222
trainer.use_vllm = True
197223
trainer.rollout_func = MagicMock(
198224
return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
199225
)
200226

201-
trainer._generate_single_turn(["prompt"])
227+
trainer._generate(["prompt"])
202228

203229
trainer.vllm_generation.sync_weights.assert_called_once()
204230
assert trainer._last_loaded_step == trainer.state.global_step
205231
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)
206232

207-
def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing(self):
233+
def test_generate_rollout_func_raises_when_required_keys_are_missing(self):
208234
trainer = self._make_trainer()
209235
trainer.rollout_func = MagicMock(return_value={"prompt_ids": [[1]], "completion_ids": [[2]]})
210236

211237
with pytest.raises(ValueError, match="rollout_func must return keys"):
212-
trainer._generate_single_turn(["prompt"])
238+
trainer._generate(["prompt"])
213239

214240

215241
class TestGRPOTrainer(TrlTestCase):

tests/test_vllm_client_server.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import pytest
2020
from packaging.version import Version
21-
from transformers import AutoModelForCausalLM, AutoTokenizer
21+
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
2222
from transformers.testing_utils import torch_device
2323

2424
from trl.generation.vllm_client import VLLMClient
@@ -31,6 +31,7 @@
3131
kill_process,
3232
require_3_accelerators,
3333
require_torch_multi_accelerator,
34+
require_vision,
3435
require_vllm,
3536
)
3637

@@ -874,3 +875,98 @@ def teardown_class(cls):
874875
# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
875876
# kill the server process and its children explicitly.
876877
kill_process(cls.server_process)
878+
879+
880+
@pytest.mark.slow
881+
@require_vllm
882+
@require_vision
883+
class TestVLLMClientServerVLM(TrlTestCase):
884+
model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
885+
886+
@classmethod
887+
def setup_class(cls):
888+
# Start the server process
889+
cls.server_process = subprocess.Popen(
890+
["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE
891+
)
892+
893+
# Initialize the client (no communicator needed for generation-only tests)
894+
cls.client = VLLMClient(connection_timeout=240, host="localhost")
895+
896+
def test_generate_with_token_ids_and_image(self):
897+
from PIL import Image
898+
899+
processor = AutoProcessor.from_pretrained(self.model_id)
900+
image1 = Image.new("RGB", (64, 64), color="red")
901+
image2 = Image.new("RGB", (64, 64), color="blue")
902+
image3 = Image.new("RGB", (64, 64), color="green")
903+
messages = [
904+
[
905+
{
906+
"role": "user",
907+
"content": [
908+
{"type": "image", "image": image1},
909+
{"type": "image", "image": image2},
910+
{"type": "text", "text": "What are the differences between these two images?"},
911+
],
912+
}
913+
],
914+
[
915+
{
916+
"role": "user",
917+
"content": [
918+
{"type": "image", "image": image3},
919+
{"type": "text", "text": "What is the color of this image?"},
920+
],
921+
}
922+
],
923+
]
924+
prompt_token_ids = processor.apply_chat_template(
925+
conversation=messages, tokenize=True, add_generation_prompt=True
926+
)
927+
outputs = self.client.generate(prompt_token_ids, images=[[image1, image2], [image3]], max_tokens=64)
928+
prompt_ids = outputs["prompt_ids"]
929+
completion_ids = outputs["completion_ids"]
930+
931+
assert len(prompt_ids) == 2
932+
assert len(completion_ids) == 2
933+
assert all(isinstance(tok, int) for tok in prompt_ids[0])
934+
assert all(isinstance(tok, int) for tok in completion_ids[0])
935+
936+
def test_generate_with_token_ids_mixed_images(self):
937+
"""Test a batch where one prompt has an image and the other does not."""
938+
from PIL import Image
939+
940+
processor = AutoProcessor.from_pretrained(self.model_id)
941+
image = Image.new("RGB", (64, 64), color="red")
942+
messages = [
943+
[
944+
{
945+
"role": "user",
946+
"content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}],
947+
}
948+
],
949+
[
950+
{
951+
"role": "user",
952+
"content": [{"type": "text", "text": "What is 1+1?"}],
953+
}
954+
],
955+
]
956+
prompt_token_ids = processor.apply_chat_template(
957+
conversation=messages, tokenize=True, add_generation_prompt=True
958+
)
959+
outputs = self.client.generate(prompt_token_ids, images=[[image], None], max_tokens=64)
960+
prompt_ids = outputs["prompt_ids"]
961+
completion_ids = outputs["completion_ids"]
962+
963+
assert len(prompt_ids) == 2
964+
assert len(completion_ids) == 2
965+
assert all(isinstance(tok, int) for tok in prompt_ids[0])
966+
assert all(isinstance(tok, int) for tok in prompt_ids[1])
967+
assert all(isinstance(tok, int) for tok in completion_ids[0])
968+
assert all(isinstance(tok, int) for tok in completion_ids[1])
969+
970+
@classmethod
971+
def teardown_class(cls):
972+
kill_process(cls.server_process)

trl/_compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _patch_vllm_disabled_tqdm() -> None:
8989
9090
- Bug introduced in https://github.com/vllm-project/vllm/pull/52
9191
- Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1)
92-
- Since TRL currently supports vLLM v0.10.2-0.14.1, we patch it here
92+
- Since TRL currently supports vLLM v0.10.2-0.17.0, we patch it here
9393
- This can be removed when TRL requires vLLM>=0.11.1
9494
"""
9595
if _is_package_version_below("vllm", "0.11.1"):

trl/experimental/cpo/cpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None
481481
# and length only differs by 1 at most
482482
num_diff_tokens = sum(
483483
a != b
484-
for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True)
484+
for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=False)
485485
)
486486
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
487487
if num_diff_tokens > 1 or num_diff_len > 1:

trl/experimental/gkd/gkd_config.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from dataclasses import dataclass, field
1616
from typing import Any
1717

18-
from transformers import TrainingArguments
19-
2018
from ...trainer.sft_config import SFTConfig
2119

2220

@@ -42,7 +40,7 @@ class GKDConfig(SFTConfig):
4240
teacher_model_name_or_path (`str`, *optional*):
4341
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
4442
trained.
45-
teacher_model_init_kwargs (`dict[str, Any]]`, *optional*):
43+
teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
4644
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
4745
from a string.
4846
disable_dropout (`bool`, *optional*, defaults to `True`):
@@ -52,7 +50,7 @@ class GKDConfig(SFTConfig):
5250
teacher-generated output).
5351
"""
5452

55-
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
53+
_VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
5654

5755
temperature: float = field(
5856
default=0.9,
@@ -84,7 +82,7 @@ class GKDConfig(SFTConfig):
8482
"model being trained."
8583
},
8684
)
87-
teacher_model_init_kwargs: dict[str, Any] | None = field(
85+
teacher_model_init_kwargs: dict[str, Any] | str | None = field(
8886
default=None,
8987
metadata={
9088
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "

trl/experimental/gold/gold_config.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from dataclasses import dataclass, field
1616
from typing import Any
1717

18-
from transformers import TrainingArguments
19-
2018
from ...trainer.sft_config import SFTConfig
2119

2220

@@ -39,13 +37,13 @@ class GOLDConfig(SFTConfig):
3937
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
4038
max_completion_length (`int`, *optional*, defaults to `128`):
4139
Maximum number of tokens to generate per completion.
42-
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
40+
teacher_model_name_or_path (`str`, *optional*):
4341
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
4442
trained.
45-
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
43+
teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
4644
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
4745
from a string.
48-
teacher_tokenizer_name_or_path (`str` or `None`, *optional*, defaults to `None`):
46+
teacher_tokenizer_name_or_path (`str`, *optional*):
4947
Tokenizer name or path for the teacher model. If None when using ULD loss, will use the same tokenizer as
5048
the student model (not recommended for cross-tokenizer distillation).
5149
disable_dropout (`bool`, *optional*, defaults to `True`):
@@ -84,7 +82,7 @@ class GOLDConfig(SFTConfig):
8482
to set this to a low value if the student and teacher models share the same GPU.
8583
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
8684
Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`).
87-
vllm_structured_outputs_regex (`str` or `None`, *optional*, defaults to `None`):
85+
vllm_structured_outputs_regex (`str`, *optional*):
8886
Regex for vLLM structured outputs for the student model.
8987
vllm_sync_frequency (`int`, *optional*, defaults to `1`):
9088
Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after
@@ -94,7 +92,7 @@ class GOLDConfig(SFTConfig):
9492
low, but waking the engine adds host–device transfer latency.
9593
"""
9694

97-
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
95+
_VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
9896

9997
# Parameters whose default values are overridden from TrainingArguments
10098
learning_rate: float = field(
@@ -153,7 +151,7 @@ class GOLDConfig(SFTConfig):
153151
"model being trained."
154152
},
155153
)
156-
teacher_model_init_kwargs: dict[str, Any] | None = field(
154+
teacher_model_init_kwargs: dict[str, Any] | str | None = field(
157155
default=None,
158156
metadata={
159157
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "

trl/experimental/minillm/minillm_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class MiniLLMConfig(GRPOConfig):
2929
arguments, please refer to the [`~transformers.TrainingArguments`] and [`GRPOConfig`] documentation.
3030
3131
Args:
32-
teacher_model_init_kwargs (`dict[str, Any]]`, *optional*):
32+
teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
3333
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
3434
from a string.
3535
disable_dropout (`bool`, *optional*, defaults to `True`):
@@ -47,7 +47,9 @@ class MiniLLMConfig(GRPOConfig):
4747
Whether to apply length normalization to the rewards.
4848
"""
4949

50-
teacher_model_init_kwargs: dict[str, Any] | None = field(
50+
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
51+
52+
teacher_model_init_kwargs: dict[str, Any] | str | None = field(
5153
default=None,
5254
metadata={
5355
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "

trl/experimental/online_dpo/online_dpo_trainer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,11 @@
9696

9797
logger = logging.get_logger(__name__)
9898

99-
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
100-
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
101-
RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]]
99+
# A reward function can be a string, interpreted as a model ID and loaded as a pretrained model, a pretrained model, or
100+
# a callable that returns a list of floats (the rewards). The callable receives prompts, completions, and additional
101+
# arguments from the trainer (refer to the trainer's source for details). To ensure forward compatibility, it should
102+
# accept **kwargs.
103+
RewardFunc = str | PreTrainedModel | Callable[..., list[float]]
102104

103105

104106
class OnlineDPOTrainer(_BaseTrainer):
@@ -750,7 +752,9 @@ def _generate_vllm_server(self, prompts, images=None):
750752
# prompt individually.
751753
ordered_set_of_prompts = all_prompts[:: self.num_generations]
752754
if has_images:
753-
ordered_set_of_images = all_images[:: self.num_generations]
755+
ordered_set_of_images = [
756+
[img] if img is not None else None for img in all_images[:: self.num_generations]
757+
]
754758
else:
755759
ordered_set_of_images = None
756760
completion_ids = self.vllm_client.generate(

0 commit comments

Comments
 (0)