1 change: 1 addition & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
"sortedcontainers",
"word2number",
"transformers",
"tinker",
]

[project.scripts]
89 changes: 89 additions & 0 deletions tests/common/vllm_test.py
@@ -1219,3 +1219,92 @@ async def test_generate(self):
response.prompt_length, 40960
        )  # If not long enough, add more files to the prompt
self.assertGreater(response.logprobs.shape[0], 1000)


class TestTinkerAPI(RayUnittestBaseAysnc):
"""Test the Tinker API integration with the vLLM engine."""

def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm"
self.config.explorer.rollout_model.engine_num = 1
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_tinker_api(self):
from tinker import types
from transformers import AutoTokenizer

engine = self.engines[0]
tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
result_dict = tokenizer.apply_chat_template(
messages,
chat_template=CHAT_TEMPLATE,
add_generation_prompt=False,
padding=False,
truncation=True,
return_tensors="pt",
add_special_tokens=False,
return_assistant_tokens_mask=True,
return_dict=True,
)
prompt = types.ModelInput.from_ints(
result_dict["input_ids"][0].tolist(),
)
# sample api without prompt logprobs
num_samples = 4
response = await engine.sample.remote(
prompt=prompt,
num_samples=num_samples,
sampling_params=types.SamplingParams(temperature=0.7), # no limit on length
)
self.assertEqual(len(response.sequences), num_samples)
for sequence in response.sequences:
print("response length:", len(sequence.tokens))
self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
self.assertEqual(sequence.stop_reason, "stop")
self.assertIsNone(response.prompt_logprobs)
self.assertIsNone(response.topk_prompt_logprobs)
# sample api with prompt logprobs
num_samples = 2
topk_prompt_logprobs = 3
response = await engine.sample.remote(
prompt=prompt,
num_samples=num_samples,
sampling_params=types.SamplingParams(temperature=0.7, max_tokens=8),
include_prompt_logprobs=True,
topk_prompt_logprobs=topk_prompt_logprobs,
)
self.assertEqual(len(response.sequences), num_samples)
for sequence in response.sequences:
self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
self.assertEqual(sequence.stop_reason, "length")
self.assertEqual(len(response.prompt_logprobs), len(prompt.to_ints()))
self.assertIsNone(response.prompt_logprobs[0])
self.assertEqual(len(response.topk_prompt_logprobs), len(prompt.to_ints()))
self.assertIsNone(response.topk_prompt_logprobs[0])
for topk_logprobs in response.topk_prompt_logprobs[1:]:
self.assertIsNotNone(topk_logprobs)
self.assertEqual(len(topk_logprobs), topk_prompt_logprobs)
        # compute_logprob-style usage: sample a single token just to obtain prompt logprobs
response = await engine.sample.remote(
prompt=prompt,
num_samples=1,
sampling_params=types.SamplingParams(max_tokens=1),
include_prompt_logprobs=True,
)
self.assertEqual(len(response.sequences), 1)
self.assertEqual(response.sequences[0].stop_reason, "length")
self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
self.assertIsNone(response.topk_prompt_logprobs)
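
The asserts above encode the alignment contract for prompt logprobs. A minimal illustrative sketch of the same checks, assuming `prompt`, `response`, and `tokenizer` are the objects from a sample call made with include_prompt_logprobs=True and topk_prompt_logprobs > 0 (the helper name is hypothetical):

def check_prompt_logprob_alignment(prompt, response, tokenizer, topk):
    # Hypothetical helper mirroring the asserts in the test above.
    prompt_ids = prompt.to_ints()
    # One entry per prompt token; index 0 stays None because the first
    # prompt token has no preceding context to be scored against.
    assert len(response.prompt_logprobs) == len(prompt_ids)
    assert response.prompt_logprobs[0] is None
    # Later entries of topk_prompt_logprobs are lists of (token_id, logprob)
    # pairs, best-ranked candidate first.
    for entries in response.topk_prompt_logprobs[1:]:
        assert len(entries) == topk
    # Sampled sequences carry raw token ids that the tokenizer can decode.
    for sequence in response.sequences:
        print(tokenizer.decode(sequence.tokens))
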
73 changes: 71 additions & 2 deletions trinity/common/models/vllm_model.py
@@ -3,13 +3,14 @@
import asyncio
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import ray
import torch
from packaging.version import parse as parse_version
from PIL import Image
from tinker import types
from transformers import AutoProcessor

from trinity.common.config import InferenceModelConfig
@@ -402,6 +403,74 @@ async def logprobs( # type: ignore [override]
dtype=torch.float32,
)

async def sample(
self,
prompt: types.ModelInput,
num_samples: int,
sampling_params: types.SamplingParams,
include_prompt_logprobs: bool = False,
topk_prompt_logprobs: int = 0,
lora_request=None,
) -> types.SampleResponse:
"""Tinker compatible sampling interface."""
params = {
"max_tokens": sampling_params.max_tokens or self.config.max_response_tokens,
"seed": sampling_params.seed or self.config.seed,
"top_k": sampling_params.top_k or self.config.top_k,
"top_p": sampling_params.top_p or self.config.top_p,
"temperature": sampling_params.temperature or self.config.temperature,
"n": num_samples,
"prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
# in vLLM, 0 means only return the chosen token's logprob
"logprobs": 0,
}
if sampling_params.stop is not None:
params["stop"] = sampling_params.stop
req_output = await self._generate_internal(
prompt={"prompt_token_ids": prompt.to_ints()},
lora_request=lora_request,
**params,
)
sequences = []
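        # vLLM reports no logprob for the first prompt token, so both lists start with a None placeholder to stay aligned with the prompt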
topk_prompt_logprobs_list: List[Optional[List[Tuple[int, float]]]] = [None]
prompt_logprobs: List[Optional[float]] = [None]

# collect prompt logprobs
if include_prompt_logprobs:
for logprob_dict in req_output.prompt_logprobs[1:]:
prompt_logprobs.append(list(logprob_dict.values())[0].logprob)
if topk_prompt_logprobs > 0:
# collect top-k prompt logprobs
# logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
logprob_items = list(logprob_dict.items())
# sort by Logprob.rank
logprob_items_sorted = sorted(logprob_items, key=lambda x: x[1].rank)
# pick topk
topk = logprob_items_sorted[:topk_prompt_logprobs]
# record as (token_id, logprob)
topk_prompt_logprobs_list.append(
[(token_id, logprob.logprob) for token_id, logprob in topk]
)
# collect response sequences
for seq_output in req_output.outputs:
seq = types.SampledSequence(
stop_reason="length" if seq_output.finish_reason == "length" else "stop",
tokens=seq_output.token_ids,
logprobs=[
list(logprob_dict.values())[0].logprob for logprob_dict in seq_output.logprobs
],
)
sequences.append(seq)
return types.SampleResponse(
sequences=sequences,
prompt_logprobs=prompt_logprobs if include_prompt_logprobs else None,
topk_prompt_logprobs=(
topk_prompt_logprobs_list
if include_prompt_logprobs and topk_prompt_logprobs > 0
else None
),
)

async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
# Send the request to the LLM engine.
self.request_id += 1
@@ -447,7 +516,7 @@ async def convert_messages_to_experience(
if len(token_ids) > self.config.max_model_len - 1:
truncate_status = "response_truncated"
self.logger.warning(
f"Warning: {len(token_ids) = } exceeds the length limit {self.config.max_model_len-1 = }"
f"Warning: {len(token_ids)=} exceeds the length limit {self.config.max_model_len - 1=}"
)
token_ids = token_ids[: self.config.max_model_len - 1]
action_mask = action_mask[: self.config.max_model_len - 1]
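
Taken together, the changes add a Tinker-style entry point on the vLLM engine actor. A minimal caller-side sketch, assuming `engine` is a handle to the rollout model actor and `tokenizer` is a Hugging Face tokenizer (both are constructed the same way in the test above); the prompt text and sampling values are illustrative:

from tinker import types

async def sample_once(engine, tokenizer):
    # Build a Tinker ModelInput from plain token ids.
    prompt = types.ModelInput.from_ints(tokenizer.encode("What is your name?"))
    response = await engine.sample.remote(
        prompt=prompt,
        num_samples=2,
        sampling_params=types.SamplingParams(temperature=0.7, max_tokens=16),
        include_prompt_logprobs=True,
        topk_prompt_logprobs=3,
    )
    # Each sampled sequence carries its token ids, per-token logprobs, and stop reason.
    for sequence in response.sequences:
        print(sequence.stop_reason, len(sequence.tokens), len(sequence.logprobs))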