2 changes: 1 addition & 1 deletion .github/workflows/sphinx-doc.yaml
@@ -11,7 +11,7 @@ on:

jobs:
pages:
timeout-minutes: 20
timeout-minutes: 30
runs-on: ${{ matrix.os }}
strategy:
matrix:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"ray[default]>=2.48.0",
"vllm>=0.9.1,<=0.11.0",
"vllm>=0.10.2,<=0.11.0",
"tensordict",
"wandb",
"omegaconf",
95 changes: 91 additions & 4 deletions tests/common/vllm_test.py
@@ -1,5 +1,6 @@
import unittest

import ray
import torch
from openai import BadRequestError
from parameterized import parameterized_class
@@ -11,6 +12,7 @@
get_model_path,
get_template_config,
)
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
@@ -310,8 +312,9 @@ async def test_api(self):
)
self.assertEqual(2, len(response.choices))
self.assertTrue(response.choices[0].logprobs is not None)
self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
# here we check the logprob of the 3rd token, because the first two tokens (`<think>`, `\n`) usually have a logprob of zero
self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
self.assertTrue(hasattr(response, "prompt_token_ids"))
self.assertTrue(len(response.prompt_token_ids) > 0)
self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -361,6 +364,89 @@ async def test_api(self):
self.assertEqual(len(self.model_wrapper_no_history.history), 0)


class DummySynchronizer:
def __init__(self):
pass

def do_nothing(self):
pass


class TestLogprobs(RayUnittestBaseAysnc):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm"
self.config.explorer.rollout_model.engine_num = 1
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_logprobs(self):
# use init process group to apply patches
sync = (
ray.remote(DummySynchronizer)
.options(name="synchronizer", namespace=self.config.ray_namespace)
.remote()
)
await sync.__ray_ready__.remote()
await self.model_wrapper.prepare()
master_address, master_port = await self.engines[0].get_available_address.remote()
await self.engines[0].init_process_group.remote(
master_address,
master_port,
world_size=1,
rank_offset=0,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
explorer_name=self.config.explorer.name,
timeout=20,
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
self.assertTrue(response_1.logprobs is not None)
self.assertTrue(len(response_1.logprobs) > 0)
self.assertTrue(response_2.logprobs is not None)
self.assertTrue(len(response_2.logprobs) > 0)
logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0)
logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8)
logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0)
logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
self.assertEqual(logprobs_1.shape, logprobs_2.shape)
self.assertEqual(logprobs_3.shape, logprobs_4.shape)
self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))


class TestAsyncAPIServer(RayUnittestBaseAysnc):
def setUp(self):
self.config = get_template_config()
@@ -403,8 +489,9 @@ async def test_api_async(self):
)
self.assertEqual(2, len(response.choices))
self.assertTrue(response.choices[0].logprobs is not None)
self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
# here we check the logprob of the 3rd token, because the first two tokens (`<think>`, `\n`) usually have a logprob of zero
self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
self.assertTrue(hasattr(response, "prompt_token_ids"))
self.assertTrue(len(response.prompt_token_ids) > 0)
self.assertTrue(hasattr(response.choices[0], "token_ids"))
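For readers skimming the new TestLogprobs case, here is a minimal editorial sketch (not part of the diff) of the behaviour it pins down. It assumes `model_wrapper` is a ModelWrapper prepared exactly as in the test above, including the init_process_group call that applies the vLLM patches; the tolerance mirrors the test's.

import torch

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is your name?"},
]
# Sample a response and record the logprobs returned with it.
response = model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]

# Recompute logprobs for the same token sequence at the same temperature...
recomputed = model_wrapper.logprobs(response.tokens.tolist(), temperature=0.8)
# ...and keep only the response part (logprobs are shifted by one position).
recomputed_response = recomputed[response.prompt_length - 1 :]

# At matching temperatures the two should agree within a loose tolerance.
assert torch.allclose(response.logprobs, recomputed_response, rtol=0.8)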
43 changes: 33 additions & 10 deletions trinity/common/models/model.py
@@ -31,11 +31,16 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
"""Generate experiences from a list of history chat messages in async."""
raise NotImplementedError

async def logprobs(self, tokens: List[int]) -> Tensor:
async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
"""Generate logprobs for a list of tokens in async."""
raise NotImplementedError

async def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
raise NotImplementedError

@@ -205,21 +210,39 @@ async def chat_mm_async(
) -> List[Experience]:
return await self.model.chat_mm.remote(messages, images=images, videos=videos, **kwargs)

def logprobs(self, tokens: List[int]) -> Tensor:
def logprobs(self, tokens: List[int], temperature: Optional[float] = None) -> Tensor:
"""Calculate the logprobs of the given tokens."""
return ray.get(self.model.logprobs.remote(tokens))
return ray.get(self.model.logprobs.remote(tokens, temperature=temperature))

async def logprobs_async(self, tokens: List[int]) -> Tensor:
async def logprobs_async(
self, tokens: List[int], temperature: Optional[float] = None
) -> Tensor:
"""Calculate the logprobs of the given tokens in async."""
return await self.model.logprobs.remote(tokens)
return await self.model.logprobs.remote(tokens, temperature=temperature)

def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience."""
return ray.get(self.model.convert_messages_to_experience.remote(messages))
return ray.get(
self.model.convert_messages_to_experience.remote(
messages, tools=tools, temperature=temperature
)
)

async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
async def convert_messages_to_experience_async(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
return await self.model.convert_messages_to_experience.remote(messages)
return await self.model.convert_messages_to_experience.remote(
messages, tools=tools, temperature=temperature
)

@property
def model_version(self) -> int:
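As a rough usage sketch of the widened ModelWrapper surface (editorial, not part of the diff): the get_weather tool schema below is invented for illustration, and `model_wrapper` is again assumed to be a prepared ModelWrapper.

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris?"},
]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# `tools` and `temperature` are now forwarded to the underlying vLLM actor.
experience = model_wrapper.convert_messages_to_experience(
    messages, tools=tools, temperature=0.8
)
# Logprobs can likewise be recomputed at an explicit temperature.
logprobs = model_wrapper.logprobs(experience.tokens.tolist(), temperature=0.8)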
26 changes: 21 additions & 5 deletions trinity/common/models/vllm_model.py
@@ -16,13 +16,13 @@

from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
from trinity.common.models.api.vllm_patch import get_vllm_version
from trinity.common.models.mm_utils import (
build_multi_modal_inputs,
convert_messages_to_mm_format,
)
from trinity.common.models.model import InferenceModel
from trinity.common.models.utils import get_action_mask_method
from trinity.common.models.vllm_patch.api_patch import get_vllm_version
from trinity.utils.log import get_logger


@@ -100,6 +100,7 @@ def __init__(
},
disable_log_stats=True,
enable_lora=config.enable_lora,
logprobs_mode="processed_logprobs",
**config.lora_kwargs,
)
if get_vllm_version() > parse_version("0.10.0"):
@@ -307,25 +308,34 @@ async def generate_mm(
]
return experiences

async def logprobs(
self, token_ids: List[int], lora_request: LoRARequest = None
async def logprobs( # type: ignore [override]
self,
token_ids: List[int],
lora_request: LoRARequest = None,
temperature: Optional[float] = None,
) -> torch.Tensor:
"""Calculate the logprobs of the given tokens in async. Please slice the result carefully
to align with the actual response length.

Args:
token_ids (List[int]): The input token ids (seq_length). Please make sure its length
does not exceed `max_model_len - 1`.
lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
temperature (float, optional): The temperature used to scale logits. Defaults to the configured sampling temperature, or 1.0 if none is set.

Returns:
A tensor of logprobs (seq_length - 1).
"""
temperature = temperature if temperature is not None else self.config.temperature
if temperature is None:
temperature = 1.0
output = await self._generate_internal(
prompt={"prompt_token_ids": token_ids},
lora_request=lora_request,
n=1,
max_tokens=1,
prompt_logprobs=0,  # vLLM returns `prompt_logprobs + 1` logprobs for each token
temperature=temperature,
)
return torch.tensor(
[list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]],
@@ -357,6 +367,7 @@ async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience."""
if self.tokenizer is None:
@@ -370,7 +381,10 @@
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
) # (seq_length, ), (seq_length, )
logprobs = await self.logprobs(token_ids=token_ids.tolist()) # (seq_length - 1,)
temperature = temperature if temperature is not None else self.config.temperature
logprobs = await self.logprobs(
token_ids=token_ids.tolist(), temperature=temperature
) # (seq_length - 1,)
return Experience(
tokens=token_ids,
logprobs=logprobs[prompt_length - 1 :],
@@ -481,7 +495,9 @@ async def run_api_server(self) -> bool:
if self.api_server_host is not None and self.api_server_port is not None:
return True # already running

from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)

api_server_host, api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
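The reason temperature has to be threaded through here: with logprobs_mode="processed_logprobs" (set in the engine args above), vLLM reports log-softmax values of the post-processed, i.e. temperature-scaled, logits, so the same token sequence yields different logprobs at different temperatures. A standalone numeric illustration (editorial, not code from this PR):

import torch

# Log-softmax of temperature-scaled logits for three candidate tokens.
logits = torch.tensor([2.0, 1.0, 0.5])
for t in (1.0, 0.8):
    logprobs = torch.log_softmax(logits / t, dim=-1)
    print(f"temperature={t}: {logprobs.tolist()}")
# The chosen token's logprob shifts with temperature, which is why `logprobs`
# and `convert_messages_to_experience` now accept and forward `temperature`.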
13 changes: 13 additions & 0 deletions trinity/common/models/vllm_patch/__init__.py
@@ -0,0 +1,13 @@
import vllm
from packaging.version import InvalidVersion
from packaging.version import parse as parse_version


def get_vllm_version():
try:
vllm_version = parse_version(vllm.__version__)
except InvalidVersion:
# for self-compiled vllm,
# we cannot parse the version, treat it as the lowest version we support
vllm_version = parse_version("0.8.5")
return vllm_version
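
A small usage sketch of the relocated helper (editorial), mirroring the version gate already used in vllm_model.py above:

from packaging.version import parse as parse_version

from trinity.common.models.vllm_patch import get_vllm_version

# Self-compiled builds with an unparseable version string are treated as 0.8.5,
# so they fall into the older-version branch.
if get_vllm_version() > parse_version("0.10.0"):
    ...  # options only available on newer vLLM releases
else:
    ...  # fallback path for older (or self-compiled) builds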
trinity/common/models/vllm_patch/api_patch.py (renamed from trinity/common/models/api/vllm_patch.py)
@@ -10,7 +10,6 @@
from typing import Optional, Union

import vllm
from packaging.version import InvalidVersion
from packaging.version import parse as parse_version
from pydantic import Field, TypeAdapter
from vllm.entrypoints.launcher import serve_http
@@ -39,6 +38,7 @@
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import FlexibleArgumentParser, set_ulimit

from trinity.common.models.vllm_patch import get_vllm_version
from trinity.utils.log import get_logger


@@ -327,16 +327,6 @@ async def patch_and_serve_http(app, sock, args):
sock.close()


def get_vllm_version():
try:
vllm_version = parse_version(vllm.__version__)
except InvalidVersion:
# for self-compiled vllm,
# we cannot parse the version, trait it as the lowest version we support
vllm_version = parse_version("0.8.5")
return vllm_version


async def run_api_server_in_ray_actor(
async_llm,
host: str,