
Commit 81b3554

Fix vLLM prompt logprobs calculation (#384)
1 parent 23e9d55 commit 81b3554

9 files changed: +299 −32 lines changed

.github/workflows/sphinx-doc.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ on:
 
 jobs:
   pages:
-    timeout-minutes: 20
+    timeout-minutes: 30
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13"
 dependencies = [
     "verl==0.5.0",
     "ray[default]>=2.48.0",
-    "vllm>=0.9.1,<=0.11.0",
+    "vllm>=0.10.2,<=0.11.0",
     "tensordict",
     "wandb",
     "omegaconf",

tests/common/vllm_test.py

Lines changed: 91 additions & 4 deletions
@@ -1,5 +1,6 @@
 import unittest
 
+import ray
 import torch
 from openai import BadRequestError
 from parameterized import parameterized_class
@@ -11,6 +12,7 @@
     get_model_path,
     get_template_config,
 )
+from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
@@ -310,8 +312,9 @@ async def test_api(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(response.choices[0].logprobs is not None)
-        self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
-        self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
+        self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
+        # check the 3rd token's logprob, because the first two tokens (`<think>`, `\n`) usually have zero logprob
+        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
         self.assertTrue(hasattr(response, "prompt_token_ids"))
         self.assertTrue(len(response.prompt_token_ids) > 0)
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -361,6 +364,89 @@ async def test_api(self):
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
 
 
+class DummySynchronizer:
+    def __init__(self):
+        pass
+
+    def do_nothing(self):
+        pass
+
+
+class TestLogprobs(RayUnittestBaseAysnc):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.rollout_model.engine_type = "vllm"
+        self.config.explorer.rollout_model.engine_num = 1
+        self.config.explorer.rollout_model.tensor_parallel_size = 1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+        self.config.explorer.rollout_model.enable_openai_api = True
+
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_logprobs(self):
+        # use init process group to apply patches
+        sync = (
+            ray.remote(DummySynchronizer)
+            .options(name="synchronizer", namespace=self.config.ray_namespace)
+            .remote()
+        )
+        await sync.__ray_ready__.remote()
+        await self.model_wrapper.prepare()
+        master_address, master_port = await self.engines[0].get_available_address.remote()
+        await self.engines[0].init_process_group.remote(
+            master_address,
+            master_port,
+            world_size=1,
+            rank_offset=0,
+            group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+            explorer_name=self.config.explorer.name,
+            timeout=20,
+        )
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is your name?"},
+        ]
+        response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
+        response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
+        self.assertTrue(response_1.logprobs is not None)
+        self.assertTrue(len(response_1.logprobs) > 0)
+        self.assertTrue(response_2.logprobs is not None)
+        self.assertTrue(len(response_2.logprobs) > 0)
+        logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0)
+        logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8)
+        logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0)
+        logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
+        self.assertEqual(logprobs_1.shape, logprobs_2.shape)
+        self.assertEqual(logprobs_3.shape, logprobs_4.shape)
+        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
+        logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
+        logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
+        logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
+        logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
+        self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
+        logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
+        logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
+        logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
+        logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
+        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
+        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
+        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
+        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+
+
 class TestAsyncAPIServer(RayUnittestBaseAysnc):
     def setUp(self):
         self.config = get_template_config()
@@ -403,8 +489,9 @@ async def test_api_async(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(response.choices[0].logprobs is not None)
-        self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
-        self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
+        self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
+        # check the 3rd token's logprob, because the first two tokens (`<think>`, `\n`) usually have zero logprob
+        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
         self.assertTrue(hasattr(response, "prompt_token_ids"))
         self.assertTrue(len(response.prompt_token_ids) > 0)
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
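The temperature checks in TestLogprobs above lean on one property: when the engine reports processed logprobs (see the logprobs_mode="processed_logprobs" change in vllm_model.py below), the values are derived from temperature-scaled logits, so scoring the same token sequence at 1.0 and 0.8 must give different numbers, while the prompt portion scored at the same temperature must agree. A minimal, self-contained sketch of that relationship in plain PyTorch (illustrative only, not code from this repository):

import torch

def scaled_logprobs(logits: torch.Tensor, token_ids: torch.Tensor, temperature: float) -> torch.Tensor:
    # Log-probabilities of the chosen tokens after temperature scaling of the logits.
    log_probs = torch.log_softmax(logits / temperature, dim=-1)
    return log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)

# Same logits, different temperatures -> different logprobs for the same tokens,
# which is what the allclose / not-allclose assertions above exercise.
logits = torch.randn(5, 32)          # (seq_len, vocab_size)
tokens = torch.randint(0, 32, (5,))  # (seq_len,)
lp_t10 = scaled_logprobs(logits, tokens, temperature=1.0)
lp_t08 = scaled_logprobs(logits, tokens, temperature=0.8)
assert lp_t10.shape == lp_t08.shape
assert not torch.allclose(lp_t10, lp_t08)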

trinity/common/models/model.py

Lines changed: 33 additions & 10 deletions
@@ -31,11 +31,16 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
         """Generate experiences from a list of history chat messages in async."""
         raise NotImplementedError
 
-    async def logprobs(self, tokens: List[int]) -> Tensor:
+    async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
         """Generate logprobs for a list of tokens in async."""
         raise NotImplementedError
 
-    async def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
+    async def convert_messages_to_experience(
+        self,
+        messages: List[dict],
+        tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
+    ) -> Experience:
         """Convert a list of messages into an experience in async."""
         raise NotImplementedError
 
@@ -205,21 +210,39 @@ async def chat_mm_async(
     ) -> List[Experience]:
         return await self.model.chat_mm.remote(messages, images=images, videos=videos, **kwargs)
 
-    def logprobs(self, tokens: List[int]) -> Tensor:
+    def logprobs(self, tokens: List[int], temperature: Optional[float] = None) -> Tensor:
         """Calculate the logprobs of the given tokens."""
-        return ray.get(self.model.logprobs.remote(tokens))
+        return ray.get(self.model.logprobs.remote(tokens, temperature=temperature))
 
-    async def logprobs_async(self, tokens: List[int]) -> Tensor:
+    async def logprobs_async(
+        self, tokens: List[int], temperature: Optional[float] = None
+    ) -> Tensor:
         """Calculate the logprobs of the given tokens in async."""
-        return await self.model.logprobs.remote(tokens)
+        return await self.model.logprobs.remote(tokens, temperature=temperature)
 
-    def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
+    def convert_messages_to_experience(
+        self,
+        messages: List[dict],
+        tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
+    ) -> Experience:
         """Convert a list of messages into an experience."""
-        return ray.get(self.model.convert_messages_to_experience.remote(messages))
+        return ray.get(
+            self.model.convert_messages_to_experience.remote(
+                messages, tools=tools, temperature=temperature
+            )
+        )
 
-    async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
+    async def convert_messages_to_experience_async(
+        self,
+        messages: List[dict],
+        tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
+    ) -> Experience:
         """Convert a list of messages into an experience in async."""
-        return await self.model.convert_messages_to_experience.remote(messages)
+        return await self.model.convert_messages_to_experience.remote(
+            messages, tools=tools, temperature=temperature
+        )
 
     @property
     def model_version(self) -> int:
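With these signatures, callers can thread the rollout temperature through both logprob-recomputation paths. A hedged usage sketch, assuming a ModelWrapper instance built the way TestLogprobs above builds one (the variable names are illustrative):

# Sketch only: `model_wrapper` is assumed to be a ModelWrapper wrapping a vLLM engine,
# constructed as in TestLogprobs above.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is your name?"},
]

# Sample at temperature 0.8, then score the same tokens at the same temperature
# so the recomputed logprobs line up with what the engine reported at rollout time.
response = model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
logprobs = model_wrapper.logprobs(response.tokens.tolist(), temperature=0.8)

# Or build an Experience directly; the temperature is forwarded to the logprob pass.
experience = model_wrapper.convert_messages_to_experience(messages, temperature=0.8)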

trinity/common/models/vllm_model.py

Lines changed: 21 additions & 5 deletions
@@ -16,13 +16,13 @@
 
 from trinity.common.config import InferenceModelConfig
 from trinity.common.experience import Experience
-from trinity.common.models.api.vllm_patch import get_vllm_version
 from trinity.common.models.mm_utils import (
     build_multi_modal_inputs,
     convert_messages_to_mm_format,
 )
 from trinity.common.models.model import InferenceModel
 from trinity.common.models.utils import get_action_mask_method
+from trinity.common.models.vllm_patch.api_patch import get_vllm_version
 from trinity.utils.log import get_logger
 
 
@@ -100,6 +100,7 @@ def __init__(
             },
             disable_log_stats=True,
             enable_lora=config.enable_lora,
+            logprobs_mode="processed_logprobs",
             **config.lora_kwargs,
         )
         if get_vllm_version() > parse_version("0.10.0"):
@@ -307,25 +308,34 @@ async def generate_mm(
             ]
         return experiences
 
-    async def logprobs(
-        self, token_ids: List[int], lora_request: LoRARequest = None
+    async def logprobs(  # type: ignore [override]
+        self,
+        token_ids: List[int],
+        lora_request: LoRARequest = None,
+        temperature: Optional[float] = None,
     ) -> torch.Tensor:
         """Calculate the logprobs of the given tokens in async. Please slice the result carefully
         to align with the actual response length.
 
         Args:
             token_ids (List[int]): The input token ids (seq_length). Please make sure the length of
                 it does not exceed `max_model_len - 1`.
+            lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
+            temperature (float, optional): The temperature for scaling logits.
 
         Returns:
             A tensor of logprobs (seq_length - 1).
         """
+        temperature = temperature if temperature is not None else self.config.temperature
+        if temperature is None:
+            temperature = 1.0
         output = await self._generate_internal(
             prompt={"prompt_token_ids": token_ids},
             lora_request=lora_request,
             n=1,
            max_tokens=1,
             prompt_logprobs=0,  # vLLM returns `prompt_logprobs + 1` logprobs for each token
+            temperature=temperature,
         )
         return torch.tensor(
             [list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]],
@@ -357,6 +367,7 @@ async def convert_messages_to_experience(
         self,
         messages: List[dict],
         tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
     ) -> Experience:
         """Convert a list of messages into an experience."""
         if self.tokenizer is None:
@@ -370,7 +381,10 @@ async def convert_messages_to_experience(
             chat_template=self.chat_template,
             enable_thinking=self.enable_thinking,
         )  # (seq_length, ), (seq_length, )
-        logprobs = await self.logprobs(token_ids=token_ids.tolist())  # (seq_length - 1,)
+        temperature = temperature if temperature is not None else self.config.temperature
+        logprobs = await self.logprobs(
+            token_ids=token_ids.tolist(), temperature=temperature
+        )  # (seq_length - 1,)
         return Experience(
             tokens=token_ids,
             logprobs=logprobs[prompt_length - 1 :],
@@ -481,7 +495,9 @@ async def run_api_server(self) -> bool:
         if self.api_server_host is not None and self.api_server_port is not None:
             return True  # already running
 
-        from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
+        from trinity.common.models.vllm_patch.api_patch import (
+            run_api_server_in_ray_actor,
+        )
 
         api_server_host, api_server_port = self.get_available_address()
         self.api_server = asyncio.create_task(
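Two details in this file drive the rest of the commit: the engine is now constructed with logprobs_mode="processed_logprobs", so reported logprobs reflect sampling-time processing (including temperature), and logprobs() requests prompt_logprobs=0 and drops the first entry, so the result has length seq_length - 1. convert_messages_to_experience therefore slices at prompt_length - 1 to keep only the response portion. A small sketch of that indexing with made-up shapes (not repository code):

import torch

seq_length, prompt_length = 10, 4
token_ids = torch.arange(seq_length)            # full prompt + response token ids
logprobs = torch.randn(seq_length - 1)          # entry i scores token_ids[i + 1] given token_ids[: i + 1]

prompt_part = logprobs[: prompt_length - 1]     # predictions for prompt tokens 1 .. prompt_length - 1
response_part = logprobs[prompt_length - 1 :]   # exactly one logprob per response token
assert response_part.shape[0] == seq_length - prompt_length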
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import vllm
+from packaging.version import InvalidVersion
+from packaging.version import parse as parse_version
+
+
+def get_vllm_version():
+    try:
+        vllm_version = parse_version(vllm.__version__)
+    except InvalidVersion:
+        # for self-compiled vllm, we cannot parse the version,
+        # treat it as the lowest version we support
+        vllm_version = parse_version("0.8.5")
+    return vllm_version
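For reference, this helper is consumed elsewhere in the commit to gate version-specific behavior; a short usage sketch mirroring the check in vllm_model.py:

from packaging.version import parse as parse_version

from trinity.common.models.vllm_patch import get_vllm_version

# Self-compiled vLLM builds with unparseable versions fall back to 0.8.5 above,
# so this comparison stays safe even when vllm.__version__ is not PEP 440 compliant.
if get_vllm_version() > parse_version("0.10.0"):
    ...  # newer-vLLM code path, as in vllm_model.py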

trinity/common/models/api/vllm_patch.py renamed to trinity/common/models/vllm_patch/api_patch.py

Lines changed: 1 addition & 11 deletions
@@ -10,7 +10,6 @@
 from typing import Optional, Union
 
 import vllm
-from packaging.version import InvalidVersion
 from packaging.version import parse as parse_version
 from pydantic import Field, TypeAdapter
 from vllm.entrypoints.launcher import serve_http
@@ -39,6 +38,7 @@
 from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.utils import FlexibleArgumentParser, set_ulimit
 
+from trinity.common.models.vllm_patch import get_vllm_version
 from trinity.utils.log import get_logger
 
 
@@ -327,16 +327,6 @@ async def patch_and_serve_http(app, sock, args):
         sock.close()
 
 
-def get_vllm_version():
-    try:
-        vllm_version = parse_version(vllm.__version__)
-    except InvalidVersion:
-        # for self-compiled vllm,
-        # we cannot parse the version, trait it as the lowest version we support
-        vllm_version = parse_version("0.8.5")
-    return vllm_version
-
-
 async def run_api_server_in_ray_actor(
     async_llm,
     host: str,