Commit a65cf0e

Implement Tinker compatible sample API (#456)
1 parent 9416705 commit a65cf0e

3 files changed: +176 −1 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ dependencies = [
     "sortedcontainers",
     "word2number",
     "transformers",
+    "tinker",
 ]

 [project.scripts]

tests/common/vllm_test.py

Lines changed: 88 additions & 0 deletions
@@ -1219,3 +1219,91 @@ async def test_generate(self):
             response.prompt_length, 40960
         )  # If not long enough, please add more files to prompt
         self.assertGreater(response.logprobs.shape[0], 1000)
+
+
+class TestTinkerAPI(RayUnittestBaseAysnc):
+    """Test the Tinker API integration with the vLLM engine."""
+
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.rollout_model.engine_type = "vllm"
+        self.config.explorer.rollout_model.engine_num = 1
+        self.config.explorer.rollout_model.tensor_parallel_size = 1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+        self.config.explorer.rollout_model.enable_openai_api = True
+
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_tinker_api(self):
+        from tinker import types
+        from transformers import AutoTokenizer
+
+        engine = self.engines[0]
+        tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is your name?"},
+        ]
+        result_dict = tokenizer.apply_chat_template(
+            messages,
+            chat_template=CHAT_TEMPLATE,
+            add_generation_prompt=False,
+            padding=False,
+            truncation=True,
+            return_tensors="pt",
+            add_special_tokens=False,
+            return_assistant_tokens_mask=True,
+            return_dict=True,
+        )
+        prompt = types.ModelInput.from_ints(
+            result_dict["input_ids"][0].tolist(),
+        )
+        # sample api without prompt logprobs
+        num_samples = 4
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=num_samples,
+            sampling_params=types.SamplingParams(temperature=0.7),  # no limit on length
+        )
+        self.assertEqual(len(response.sequences), num_samples)
+        for sequence in response.sequences:
+            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
+            self.assertEqual(sequence.stop_reason, "stop")
+        self.assertIsNone(response.prompt_logprobs)
+        self.assertIsNone(response.topk_prompt_logprobs)
+        # sample api with prompt logprobs
+        num_samples = 2
+        topk_prompt_logprobs = 3
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=num_samples,
+            sampling_params=types.SamplingParams(temperature=0.7, max_tokens=8),
+            include_prompt_logprobs=True,
+            topk_prompt_logprobs=topk_prompt_logprobs,
+        )
+        self.assertEqual(len(response.sequences), num_samples)
+        for sequence in response.sequences:
+            self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
+            self.assertEqual(sequence.stop_reason, "length")
+        self.assertEqual(len(response.prompt_logprobs), len(prompt.to_ints()))
+        self.assertIsNone(response.prompt_logprobs[0])
+        self.assertEqual(len(response.topk_prompt_logprobs), len(prompt.to_ints()))
+        self.assertIsNone(response.topk_prompt_logprobs[0])
+        for topk_logprobs in response.topk_prompt_logprobs[1:]:
+            self.assertIsNotNone(topk_logprobs)
+            self.assertEqual(len(topk_logprobs), topk_prompt_logprobs)
+        # compute_logprob api
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+        )
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
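
For quick reference, the snippet below condenses the calling pattern exercised by this test into a minimal sketch. It assumes engine is a Ray actor handle obtained from create_inference_models(config), as in TestTinkerAPI.setUp; the token ids passed to ModelInput.from_ints and the helper name sample_once are placeholders, not part of the commit.

# Minimal usage sketch (not part of this commit): how a caller might invoke the
# new Tinker-compatible sample API on a rollout engine. Assumes `engine` is a
# Ray actor handle from create_inference_models(config); token ids are placeholders.
from tinker import types


async def sample_once(engine):
    prompt = types.ModelInput.from_ints([1, 2, 3, 4])  # placeholder token ids
    response = await engine.sample.remote(
        prompt=prompt,
        num_samples=2,
        sampling_params=types.SamplingParams(temperature=0.7, max_tokens=16),
        include_prompt_logprobs=True,
        topk_prompt_logprobs=3,
    )
    for seq in response.sequences:
        # each sampled sequence carries its tokens, per-token logprobs, and a stop reason
        print(seq.stop_reason, len(seq.tokens), len(seq.logprobs))
    # index 0 is None: vLLM returns no logprob for the first prompt token
    print(response.prompt_logprobs[0], response.prompt_logprobs[1:4])
    return response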

trinity/common/models/vllm_model.py

Lines changed: 87 additions & 1 deletion
@@ -3,7 +3,7 @@
 import asyncio
 import os
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple

 import numpy as np
 import ray
@@ -402,6 +402,92 @@ async def logprobs(  # type: ignore [override]
             dtype=torch.float32,
         )

+    async def sample(
+        self,
+        prompt: Any,
+        num_samples: int,
+        sampling_params: Any,
+        include_prompt_logprobs: bool = False,
+        topk_prompt_logprobs: int = 0,
+        lora_request: Optional[Any] = None,
+    ) -> Any:
+        """Tinker compatible sampling interface.
+
+        Args:
+            prompt (ModelInput): The input prompt.
+            num_samples (int): The number of samples to generate.
+            sampling_params (SamplingParams): The sampling parameters.
+            include_prompt_logprobs (bool): Whether to include prompt logprobs.
+            topk_prompt_logprobs (int): The top-k prompt logprobs to include.
+            lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
+        Returns:
+            SampleResponse: The sample response.
+        """
+        from tinker.types import SampledSequence, SampleResponse
+
+        params = {
+            "max_tokens": sampling_params.max_tokens
+            if sampling_params.max_tokens is not None
+            else self.config.max_response_tokens,
+            "seed": sampling_params.seed if sampling_params.seed is not None else self.config.seed,
+            "top_k": sampling_params.top_k,
+            "top_p": sampling_params.top_p,
+            "temperature": sampling_params.temperature,
+            "n": num_samples,
+            "prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
+            # in vLLM, 0 means only return the chosen token's logprob
+            "logprobs": 0,
+        }
+        if sampling_params.stop is not None:
+            params["stop"] = sampling_params.stop
+        req_output = await self._generate_internal(
+            prompt={"prompt_token_ids": prompt.to_ints()},
+            lora_request=lora_request,
+            **params,
+        )
+        sequences = []
+        # vLLM's prompt_logprobs output does not include a value for the first token.
+        # Initialize with [None] to align with the prompt tokens.
+        topk_prompt_logprobs_list: List[Optional[List[Tuple[int, float]]]] = [None]
+        prompt_logprobs: List[Optional[float]] = [None]

+        # collect prompt logprobs
+        if include_prompt_logprobs:
+            for logprob_dict in req_output.prompt_logprobs[1:]:
+                prompt_logprobs.append(next(iter(logprob_dict.values())).logprob)
+                if topk_prompt_logprobs > 0:
+                    # collect top-k prompt logprobs
+                    # logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
+                    logprob_items = list(logprob_dict.items())
+                    # sort by Logprob.rank
+                    logprob_items_sorted = sorted(logprob_items, key=lambda x: x[1].rank)
+                    # pick topk
+                    topk = logprob_items_sorted[:topk_prompt_logprobs]
+                    # record as (token_id, logprob)
+                    topk_prompt_logprobs_list.append(
+                        [(token_id, logprob.logprob) for token_id, logprob in topk]
+                    )
+        # collect response sequences
+        for seq_output in req_output.outputs:
+            seq = SampledSequence(
+                stop_reason="length" if seq_output.finish_reason == "length" else "stop",
+                tokens=seq_output.token_ids,
+                logprobs=[
+                    next(iter(logprob_dict.values())).logprob
+                    for logprob_dict in seq_output.logprobs
+                ],
+            )
+            sequences.append(seq)
+        return SampleResponse(
+            sequences=sequences,
+            prompt_logprobs=prompt_logprobs if include_prompt_logprobs else None,
+            topk_prompt_logprobs=(
+                topk_prompt_logprobs_list
+                if include_prompt_logprobs and topk_prompt_logprobs > 0
+                else None
+            ),
+        )
+
     async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
         # Send the request to the LLM engine.
         self.request_id += 1
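
The top-k handling inside sample() sorts each vLLM prompt-logprob dict by its Logprob.rank field and keeps the first topk_prompt_logprobs entries as (token_id, logprob) pairs. The standalone sketch below replays that selection step on toy data; FakeLogprob is a hypothetical stand-in for vLLM's Logprob objects (which expose .logprob and .rank), and the token ids and values are invented for illustration.

# Standalone sketch of the rank-based top-k selection used in sample() above.
# FakeLogprob is a stand-in for vLLM's Logprob objects; values are made up.
from collections import namedtuple

FakeLogprob = namedtuple("FakeLogprob", ["logprob", "rank"])

logprob_dict = {
    42: FakeLogprob(logprob=-0.10, rank=1),  # chosen / most likely token
    7: FakeLogprob(logprob=-1.25, rank=2),
    99: FakeLogprob(logprob=-2.80, rank=3),
    13: FakeLogprob(logprob=-4.10, rank=4),
}

topk_prompt_logprobs = 3
# sort candidates by rank, keep the top-k, record (token_id, logprob) pairs
topk = sorted(logprob_dict.items(), key=lambda x: x[1].rank)[:topk_prompt_logprobs]
topk_pairs = [(token_id, lp.logprob) for token_id, lp in topk]
print(topk_pairs)  # [(42, -0.1), (7, -1.25), (99, -2.8)]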
