2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"ray[default]>=2.48.0",
"vllm>=0.9.1,<=0.11.0",
"vllm>=0.10.0,<=0.11.0",
"tensordict",
"wandb",
"omegaconf",
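The vLLM floor moves from 0.9.1 to 0.10.0 while the 0.11.0 ceiling stays. A minimal, illustrative sketch of a runtime check mirroring the new constraint (the bounds come from the pyproject.toml line above; the check itself is not part of this PR):

import vllm
from packaging.version import parse as parse_version

# mirrors "vllm>=0.10.0,<=0.11.0" from pyproject.toml
installed = parse_version(vllm.__version__)
assert parse_version("0.10.0") <= installed <= parse_version("0.11.0"), (
    f"vllm {installed} is outside the supported range >=0.10.0,<=0.11.0"
)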
6 changes: 4 additions & 2 deletions trinity/common/models/vllm_model.py
@@ -16,13 +16,13 @@

from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
-from trinity.common.models.api.vllm_patch import get_vllm_version
from trinity.common.models.mm_utils import (
build_multi_modal_inputs,
convert_messages_to_mm_format,
)
from trinity.common.models.model import InferenceModel
from trinity.common.models.utils import get_action_mask_method
+from trinity.common.models.vllm_patch.api_patch import get_vllm_version
from trinity.utils.log import get_logger


@@ -481,7 +481,9 @@ async def run_api_server(self) -> bool:
if self.api_server_host is not None and self.api_server_port is not None:
return True # already running

-        from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
+        from trinity.common.models.vllm_patch.api_patch import (
+            run_api_server_in_ray_actor,
+        )

api_server_host, api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
13 changes: 13 additions & 0 deletions trinity/common/models/vllm_patch/__init__.py
@@ -0,0 +1,13 @@
import vllm
from packaging.version import InvalidVersion
from packaging.version import parse as parse_version


def get_vllm_version():
try:
vllm_version = parse_version(vllm.__version__)
except InvalidVersion:
# for self-compiled vllm,
        # we cannot parse the version, so treat it as the lowest version we support
vllm_version = parse_version("0.8.5")
return vllm_version
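A minimal usage sketch for this helper (illustrative only; the 0.10.0 threshold mirrors the check in worker_patch.py below):

from packaging.version import parse as parse_version

from trinity.common.models.vllm_patch import get_vllm_version

if get_vllm_version() >= parse_version("0.10.0"):
    # version is new enough for the prompt-logprobs patch introduced below
    ...  # placeholder: apply version-gated patches here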
trinity/common/models/vllm_patch/api_patch.py
@@ -10,7 +10,6 @@
from typing import Optional, Union

import vllm
-from packaging.version import InvalidVersion
from packaging.version import parse as parse_version
from pydantic import Field, TypeAdapter
from vllm.entrypoints.launcher import serve_http
@@ -39,6 +38,7 @@
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import FlexibleArgumentParser, set_ulimit

+from trinity.common.models.vllm_patch import get_vllm_version
from trinity.utils.log import get_logger


@@ -327,16 +327,6 @@ async def patch_and_serve_http(app, sock, args):
sock.close()


-def get_vllm_version():
-    try:
-        vllm_version = parse_version(vllm.__version__)
-    except InvalidVersion:
-        # for self-compiled vllm,
-        # we cannot parse the version, trait it as the lowest version we support
-        vllm_version = parse_version("0.8.5")
-    return vllm_version


async def run_api_server_in_ray_actor(
async_llm,
host: str,
Expand Down
125 changes: 125 additions & 0 deletions trinity/common/models/vllm_patch/worker_patch.py
@@ -0,0 +1,125 @@
from types import MethodType
from typing import Optional

import torch
import vllm
from packaging.version import parse as parse_version
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

from trinity.common.models.vllm_patch import get_vllm_version


def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner):
"""Patch vLLM model runner to support prompt logprobs extraction."""
if get_vllm_version() < parse_version("0.10.0"):
raise ValueError(
f"Unsupported vllm version: {vllm.__version__}. "
"This patch requires vllm version >= 0.10.0, <= 0.11.0."
)

def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}

in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = num_scheduled_tokens[req_id]

# Get metadata for this request.
request = self.requests[req_id]
if request.prompt_token_ids is None:
# Prompt logprobs is incompatible with prompt embeddings
continue

num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True
)

# Set up target LogprobsTensors object.
logprobs_tensors = in_progress_dict.get(req_id)
if not logprobs_tensors:
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprobs_tensors = LogprobsTensors.empty_cpu(
num_prompt_tokens - 1, num_prompt_logprobs + 1
)
in_progress_dict[req_id] = logprobs_tensors

# Determine number of logits to retrieve.
start_idx = request.num_computed_tokens
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
prompt_logprobs_dict[req_id] = logprobs_tensors

if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue

# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc.np[req_idx].item()
prompt_hidden_states = hidden_states[offset : offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states)

# PATCH START
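            # Rescale logits by the request's sampling temperature (skipped for
            # near-greedy temperatures below 1e-5) so prompt logprobs reflect the
            # temperature-adjusted distribution used at sampling time.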
temp = request.sampling_params.temperature
if temp is None or temp >= 1e-5:
logits.div_(temp)
# PATCH END

# Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want
# to gather the logprob for.
tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits]

# Compute prompt logprobs.
logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids
)

# Transfer GPU->CPU async.
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, non_blocking=True)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, non_blocking=True)

# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]
del in_progress_dict[req_id]

# Must synchronize the non-blocking GPU->CPU transfers.
if prompt_logprobs_dict:
self._sync_device()

return prompt_logprobs_dict

model_runner._get_prompt_logprobs_dict = MethodType(_get_prompt_logprobs_dict, model_runner)
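The final line binds the replacement with types.MethodType, so only this GPUModelRunner instance is patched and the class itself stays untouched. A self-contained sketch of that binding pattern, using hypothetical names purely for illustration:

from types import MethodType


class Runner:
    def compute(self) -> str:
        return "original"


def patched_compute(self) -> str:
    return "patched"


runner = Runner()
# rebind on the instance only; other Runner instances keep the original method
runner.compute = MethodType(patched_compute, runner)
assert runner.compute() == "patched"
assert Runner().compute() == "original"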
2 changes: 2 additions & 0 deletions trinity/common/models/vllm_worker.py
@@ -5,6 +5,7 @@
import torch.distributed
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader

+from trinity.common.models.vllm_patch.worker_patch import patch_vllm_prompt_logprobs
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.distributed import init_process_group
from trinity.utils.log import get_logger
@@ -56,6 +57,7 @@ def init_process_group(
self.synchronizer = Synchronizer.get_actor(namespace=self._namespace)
self._checkpoint_converter = None
patch_vllm_moe_model_weight_loader(self.model_runner.model)
+        patch_vllm_prompt_logprobs(self.model_runner)

def update_weight(self):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""