Commit 1b9ec28

Fix n does not take effect in vLLM v1 engine (#36)

1 parent d613940

File tree: 2 files changed (+9, −14 lines)


tests/common/vllm_test.py

Lines changed: 3 additions & 6 deletions
@@ -1,7 +1,6 @@
 import os
 import unittest
 
-import ray
 import torch
 from transformers import AutoTokenizer
 
@@ -131,25 +130,25 @@ def test_generate(self):
 
 class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm"
         self.config.explorer.tensor_parallel_size = 1
         self.config.explorer.engine_num = 2
+        self.config.explorer.repeat_times = 2
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
 
 
 class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 2
         self.config.explorer.tensor_parallel_size = 1
+        self.config.explorer.repeat_times = 2
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
@@ -158,7 +157,6 @@ def setUp(self):
 
 class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
@@ -172,12 +170,12 @@ def setUp(self):
 
 class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 2
         self.config.explorer.tensor_parallel_size = 2
+        self.config.explorer.repeat_times = 2
         self.config.explorer.use_v1 = True
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
@@ -186,7 +184,6 @@ def setUp(self):
 
 class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
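Note on the test change: each setUp now sets `repeat_times = 2`, which maps to sampling `n > 1` per prompt, the exact path this commit fixes; the per-suite `ray.init(ignore_reinit_error=True)` calls are dropped, presumably because Ray initialization is handled elsewhere (e.g., by `RayUnittestBase`). A minimal sketch of the property these suites now exercise; `ModelWrapper` and `repeat_times` come from the diff above, but the `generate` call and assertion below are illustrative assumptions, not the repository's actual test code:

    # Hypothetical check: with n (repeat_times) > 1, the wrapper should
    # return one experience per requested sample.
    def assert_n_takes_effect(model_wrapper, repeat_times: int) -> None:
        experiences = model_wrapper.generate("What is 2 + 2?")
        # Before this fix, the vLLM v1 engine could surface a single
        # completion even when repeat_times was greater than 1.
        assert len(experiences) == repeat_times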

trinity/common/models/vllm_async_model.py

Lines changed: 6 additions & 8 deletions
@@ -12,6 +12,7 @@
 import ray
 import torch
 import vllm
+from vllm.sampling_params import RequestOutputKind
 
 from trinity.common.config import Config
 from trinity.common.experience import Experience
@@ -61,6 +62,7 @@ def __init__(
             truncate_prompt_tokens=config.model.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
+            output_kind=RequestOutputKind.FINAL_ONLY,
             logprobs=config.explorer.logprobs,
         )
         self.request_id = 0
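Why `output_kind` matters: vLLM's `SamplingParams.output_kind` controls how results are surfaced over a streamed request, with `CUMULATIVE`, `DELTA`, and `FINAL_ONLY` variants in `vllm.sampling_params.RequestOutputKind`. Pinning it to `FINAL_ONLY` means the one output consumed per request is the final `RequestOutput` carrying all `n` completions, rather than a partial streaming update. A minimal sketch of the setting, with placeholder values for everything except the import shown in the diff:

    from vllm import SamplingParams
    from vllm.sampling_params import RequestOutputKind

    params = SamplingParams(
        n=2,  # two completions per prompt (driven by repeat_times above)
        # FINAL_ONLY: emit a single RequestOutput at completion that
        # contains all n finished sequences, instead of incremental
        # (DELTA) or cumulative partial outputs.
        output_kind=RequestOutputKind.FINAL_ONLY,
    )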
@@ -148,10 +150,8 @@ async def generate_async(self, prompt: str, **kwargs) -> List[Experience]:
         Returns:
             A list of experiences.
         """
-        request_id = self.request_id
-        self.request_id += 1
         async with self.semaphore:
-            output = await self._generate_internal(request_id=request_id, prompt=prompt, **kwargs)
+            output = await self._generate_internal(prompt=prompt, **kwargs)
             experiences = [
                 Experience(
                     tokens=torch.cat(
@@ -186,11 +186,8 @@ async def generate_async(self, prompt: str, **kwargs) -> List[Experience]:
 
     async def logprobs_async(self, token_ids: List[int]) -> torch.Tensor:
         """Calculate the logprobs of the given tokens in async."""
-        request_id = self.request_id
-        self.request_id += 1
         async with self.semaphore:
             output = await self._generate_internal(
-                request_id=request_id,
                 prompt={"prompt_token_ids": token_ids},
                 n=1,
                 max_tokens=1,
@@ -205,10 +202,11 @@ async def logprobs_async(self, token_ids: List[int]) -> torch.Tensor:
             dtype=torch.float32,
         )
 
-    async def _generate_internal(self, request_id: int, prompt: Any, **kwargs) -> Any:
+    async def _generate_internal(self, prompt: Any, **kwargs) -> Any:
         # Send the request to the LLM engine.
+        self.request_id += 1
         stream = self.async_llm.generate(
-            request_id=str(request_id),
+            request_id=str(self.request_id),
             prompt=prompt,
             sampling_params=self._create_sampling_params(**kwargs),
         )
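The `request_id` bookkeeping is also consolidated: callers no longer read and bump the shared counter themselves; `_generate_internal` increments it once per engine request, so every request the wrapper submits gets a distinct ID. A simplified stand-in for that invariant (names below are illustrative, not the wrapper's actual class):

    class RequestIdMinter:
        """Mint a unique, monotonically increasing request ID per call."""

        def __init__(self) -> None:
            self.request_id = 0

        def mint(self) -> str:
            # One increment per engine request, in exactly one place,
            # regardless of which async entry point triggered it.
            self.request_id += 1
            return str(self.request_id)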
