2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"ray[default]>=2.48.0",
"vllm>=0.9.1,<=0.11.0",
"vllm>=0.10.2,<=0.11.0",
"tensordict",
"wandb",
"omegaconf",
83 changes: 83 additions & 0 deletions tests/common/vllm_test.py
@@ -1,5 +1,6 @@
import unittest

import ray
import torch
from openai import BadRequestError
from parameterized import parameterized_class
@@ -11,6 +12,7 @@
get_model_path,
get_template_config,
)
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
@@ -361,6 +363,87 @@ async def test_api(self):
self.assertEqual(len(self.model_wrapper_no_history.history), 0)


class DummySynchronizer:
def __init__(self):
pass

def do_nothing(self):
pass


class TestLogprobs(RayUnittestBaseAysnc):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm"
self.config.explorer.rollout_model.engine_num = 1
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_logprobs(self):
# initialize the process group so that the patches are applied
sync = (
ray.remote(DummySynchronizer)
.options(name="synchronizer", namespace=self.config.ray_namespace)
.remote()
)
await sync.__ray_ready__.remote()
await self.model_wrapper.prepare()
master_address, master_port = await self.engines[0].get_available_address.remote()
await self.engines[0].init_process_group.remote(
master_address,
master_port,
world_size=1,
rank_offset=0,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
explorer_name=self.config.explorer.name,
timeout=20,
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.5, logprobs=True)[0]
self.assertTrue(response_1.logprobs is not None)
self.assertTrue(len(response_1.logprobs) > 0)
self.assertTrue(response_2.logprobs is not None)
self.assertTrue(len(response_2.logprobs) > 0)
logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0)
logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.5)
logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0)
logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.5)
self.assertEqual(logprobs_1.shape, logprobs_2.shape)
self.assertFalse(torch.equal(logprobs_1, logprobs_2))
logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
self.assertFalse(torch.isclose(logprobs_1_prompt, logprobs_2_prompt, atol=1e-2).all())
self.assertFalse(torch.isclose(logprobs_3_prompt, logprobs_4_prompt, atol=1e-2).all())
self.assertTrue(torch.isclose(logprobs_1_prompt, logprobs_3_prompt, atol=1e-2).all())
self.assertTrue(torch.isclose(logprobs_2_prompt, logprobs_4_prompt, atol=1e-2).all())
logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
self.assertTrue(torch.isclose(response_1.logprobs, logprobs_1_response, atol=1e-2).all())
self.assertFalse(torch.isclose(response_1.logprobs, logprobs_2_response, atol=1e-2).all())
self.assertTrue(torch.isclose(response_2.logprobs, logprobs_4_response, atol=1e-2).all())
self.assertFalse(torch.isclose(response_2.logprobs, logprobs_3_response, atol=1e-2).all())


class TestAsyncAPIServer(RayUnittestBaseAysnc):
def setUp(self):
self.config = get_template_config()
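The assertions above hinge on logprobs being temperature dependent. A minimal, self-contained sketch of that property follows; it assumes (as the new `logprobs_mode="processed_logprobs"` engine argument suggests) that the reported logprobs are taken after temperature scaling of the logits, i.e. `log_softmax(logits / T)`. The tensor shapes and values are illustrative only, not part of the PR.

```python
import torch

# Sketch of the property the test relies on: scoring the same tokens at
# different temperatures yields different logprobs, because the logprobs are
# computed from temperature-scaled logits.
def scaled_logprobs(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    return torch.log_softmax(logits / temperature, dim=-1)

logits = torch.randn(4, 1000)                  # toy (seq_len, vocab_size) logits
lp_t10 = scaled_logprobs(logits, temperature=1.0)
lp_t05 = scaled_logprobs(logits, temperature=0.5)
assert lp_t10.shape == lp_t05.shape
assert not torch.allclose(lp_t10, lp_t05)      # mirrors the assertFalse(...) checks above
```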
43 changes: 33 additions & 10 deletions trinity/common/models/model.py
@@ -31,11 +31,16 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
"""Generate experiences from a list of history chat messages in async."""
raise NotImplementedError

async def logprobs(self, tokens: List[int]) -> Tensor:
async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
"""Generate logprobs for a list of tokens in async."""
raise NotImplementedError

async def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
raise NotImplementedError

@@ -205,21 +210,39 @@ async def chat_mm_async(
) -> List[Experience]:
return await self.model.chat_mm.remote(messages, images=images, videos=videos, **kwargs)

def logprobs(self, tokens: List[int]) -> Tensor:
def logprobs(self, tokens: List[int], temperature: Optional[float] = None) -> Tensor:
"""Calculate the logprobs of the given tokens."""
return ray.get(self.model.logprobs.remote(tokens))
return ray.get(self.model.logprobs.remote(tokens, temperature=temperature))

async def logprobs_async(self, tokens: List[int]) -> Tensor:
async def logprobs_async(
self, tokens: List[int], temperature: Optional[float] = None
) -> Tensor:
"""Calculate the logprobs of the given tokens in async."""
return await self.model.logprobs.remote(tokens)
return await self.model.logprobs.remote(tokens, temperature=temperature)

def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience."""
return ray.get(self.model.convert_messages_to_experience.remote(messages))
return ray.get(
self.model.convert_messages_to_experience.remote(
messages, tools=tools, temperature=temperature
)
)

async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
async def convert_messages_to_experience_async(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
return await self.model.convert_messages_to_experience.remote(messages)
return await self.model.convert_messages_to_experience.remote(
messages, tools=tools, temperature=temperature
)

@property
def model_version(self) -> int:
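For reference, a usage sketch of the widened `ModelWrapper` surface is below. It assumes a `model_wrapper` constructed as in `TestLogprobs.setUp` above (an `explore`-mode config with a vLLM rollout engine); the messages and temperature value are illustrative.

```python
# Sketch only; assumes `model_wrapper` was built as in TestLogprobs.setUp above.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is your name?"},
]

# Generate with per-token logprobs at a chosen sampling temperature.
response = model_wrapper.chat(messages, n=1, temperature=0.5, logprobs=True)[0]

# Re-score the full token sequence at the same temperature; the response slice
# of the result should closely match `response.logprobs` (see the test above).
rescored = model_wrapper.logprobs(response.tokens.tolist(), temperature=0.5)

# Build an Experience directly from messages, forwarding tools and temperature
# to the remote engine.
experience = model_wrapper.convert_messages_to_experience(
    messages, tools=None, temperature=0.5
)
```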
26 changes: 21 additions & 5 deletions trinity/common/models/vllm_model.py
@@ -16,13 +16,13 @@

from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
from trinity.common.models.api.vllm_patch import get_vllm_version
from trinity.common.models.mm_utils import (
build_multi_modal_inputs,
convert_messages_to_mm_format,
)
from trinity.common.models.model import InferenceModel
from trinity.common.models.utils import get_action_mask_method
from trinity.common.models.vllm_patch.api_patch import get_vllm_version
from trinity.utils.log import get_logger


@@ -100,6 +100,7 @@ def __init__(
},
disable_log_stats=True,
enable_lora=config.enable_lora,
logprobs_mode="processed_logprobs",
**config.lora_kwargs,
)
if get_vllm_version() > parse_version("0.10.0"):
@@ -307,25 +308,34 @@ async def generate_mm(
]
return experiences

async def logprobs(
self, token_ids: List[int], lora_request: LoRARequest = None
async def logprobs( # type: ignore [override]
self,
token_ids: List[int],
lora_request: LoRARequest = None,
temperature: Optional[float] = None,
) -> torch.Tensor:
"""Calculate the logprobs of the given tokens in async. Please slice the result carefully
to align with the actual response length.

Args:
token_ids (List[int]): The input token ids (seq_length). Please make sure its length
does not exceed `max_model_len - 1`.
lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
temperature (float, optional): The temperature used to scale the logits. Defaults to the engine's configured temperature, or 1.0 if unset.

Returns:
A tensor of logprobs (seq_length - 1).
"""
temperature = temperature if temperature is not None else self.config.temperature
if temperature is None:
temperature = 1.0
output = await self._generate_internal(
prompt={"prompt_token_ids": token_ids},
lora_request=lora_request,
n=1,
max_tokens=1,
prompt_logprobs=0, # vLLM returns `prompt_logprobs + 1` logprobs for each token
temperature=temperature,
)
return torch.tensor(
[list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]],
@@ -357,6 +367,7 @@ async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience."""
if self.tokenizer is None:
@@ -370,7 +381,10 @@
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
) # (seq_length, ), (seq_length, )
logprobs = await self.logprobs(token_ids=token_ids.tolist()) # (seq_length - 1,)
temperature = temperature if temperature is not None else self.config.temperature
logprobs = await self.logprobs(
token_ids=token_ids.tolist(), temperature=temperature
) # (seq_length - 1,)
return Experience(
tokens=token_ids,
logprobs=logprobs[prompt_length - 1 :],
@@ -481,7 +495,9 @@ async def run_api_server(self) -> bool:
if self.api_server_host is not None and self.api_server_port is not None:
return True # already running

from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)

api_server_host, api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
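The docstring above asks callers to slice the returned logprobs carefully; the convention used by `convert_messages_to_experience` is sketched below with toy values (the token ids and prompt length are illustrative, not part of the PR).

```python
import torch

# Toy illustration of the slicing convention: logprobs() returns seq_length - 1
# values, where entry i scores token_ids[i + 1] given token_ids[: i + 1].
token_ids = list(range(10))               # pretend: 4 prompt + 6 response tokens
prompt_length = 4
full = torch.randn(len(token_ids) - 1)    # stand-in for `await self.logprobs(token_ids)`
prompt_part = full[: prompt_length - 1]   # scores for prompt tokens 1..prompt_length-1
response_part = full[prompt_length - 1:]  # what Experience keeps as `logprobs`
assert response_part.numel() == len(token_ids) - prompt_length
```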
13 changes: 13 additions & 0 deletions trinity/common/models/vllm_patch/__init__.py
@@ -0,0 +1,13 @@
import vllm
from packaging.version import InvalidVersion
from packaging.version import parse as parse_version


def get_vllm_version():
try:
vllm_version = parse_version(vllm.__version__)
except InvalidVersion:
# for self-compiled vllm,
# we cannot parse the version, treat it as the lowest version we support
vllm_version = parse_version("0.8.5")
return vllm_version
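A short usage sketch for the relocated helper, mirroring the version gate already used in `vllm_model.py`; the body of the gated branch is illustrative.

```python
from packaging.version import parse as parse_version

from trinity.common.models.vllm_patch import get_vllm_version

# Gate version-specific behaviour, as vllm_model.py does above.
if get_vllm_version() > parse_version("0.10.0"):
    pass  # e.g. pass engine arguments that only exist on vLLM 0.10+ (illustrative)
```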
trinity/common/models/{api/vllm_patch.py → vllm_patch/api_patch.py}
@@ -10,7 +10,6 @@
from typing import Optional, Union

import vllm
from packaging.version import InvalidVersion
from packaging.version import parse as parse_version
from pydantic import Field, TypeAdapter
from vllm.entrypoints.launcher import serve_http
@@ -39,6 +38,7 @@
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import FlexibleArgumentParser, set_ulimit

from trinity.common.models.vllm_patch import get_vllm_version
from trinity.utils.log import get_logger


@@ -327,16 +327,6 @@ async def patch_and_serve_http(app, sock, args):
sock.close()


def get_vllm_version():
try:
vllm_version = parse_version(vllm.__version__)
except InvalidVersion:
# for self-compiled vllm,
# we cannot parse the version, treat it as the lowest version we support
vllm_version = parse_version("0.8.5")
return vllm_version


async def run_api_server_in_ray_actor(
async_llm,
host: str,