Commit 154bd71
Fix the mismatch between vLLM OpenAI API and vLLM generate (#431)
1 parent: 613194d

11 files changed: +254 −79 lines

tests/common/vllm_test.py

Lines changed: 169 additions & 29 deletions
@@ -1,7 +1,7 @@
+import asyncio
 import os
 import unittest
 
-import ray
 import torch
 from openai import BadRequestError
 from parameterized import parameterized_class
@@ -13,7 +13,6 @@
     get_model_path,
     get_template_config,
 )
-from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
@@ -29,6 +28,16 @@ def print_debug(*args):
     print(*args)
 
 
+async def prepare_engines(engines, auxiliary_engines):
+    prepare_model_refs = []
+    for engine in engines:
+        prepare_model_refs.append(engine.prepare.remote())
+    for engines in auxiliary_engines:
+        for engine in engines:
+            prepare_model_refs.append(engine.prepare.remote())
+    await asyncio.gather(*prepare_model_refs)
+
+
 # Qwen2.5 chat template with {% generation %} mark
 CHAT_TEMPLATE = r"""
 {%- if tools %}
@@ -127,6 +136,7 @@ def setUp(self):
     async def test_generate(
         self,
     ):
+        await prepare_engines(self.engines, self.auxiliary_engines)
         await self.model_wrapper.prepare()
         self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
         self.assertEqual(await self.model_wrapper.model_path_async, self.config.model.model_path)
@@ -244,6 +254,7 @@ def setUp(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
 
     async def test_model_len(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
         await self.model_wrapper.prepare()
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -311,6 +322,7 @@ def setUp(self):
         self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
 
     async def test_model_len(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
         await self.model_wrapper.prepare()
         messages = [
             {"role": "user", "content": "How are you?"},
@@ -362,6 +374,7 @@ def setUp(self):
         )
 
     async def test_api(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
         await self.model_wrapper.prepare()
         await self.model_wrapper_no_history.prepare()
         openai_client = self.model_wrapper.get_openai_client()
@@ -435,12 +448,42 @@ async def test_api(self):
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
 
 
-class DummySynchronizer:
-    def __init__(self):
-        pass
+SYSTEM_PROMPT = """
+You are Qwen, created by Alibaba Cloud. You are a helpful assistant. You are walking on a frozen lake.
+
+FrozenLake Quick Guide
+Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.
+
+Symbols:
+_ Frozen | O Hole | G Goal | P Player
+
+Rules:
+1. Avoid falling into holes (O).
+2. Frozen tiles are slippery, you may move perpendicular to your intended direction.
+
+Valid Action (separated by | ):
+Up | Down | Left | Right
+
+Rewards:
+Fall into hole: 0
+Reach goal: +1.0
+
+You will be provided the current observation, please decide on the next Action.
+You should show your thought process and then input the final action in ``` ```.
+You should only output the NEXT ACTION at each interation in the ``` ```. For example, if you want to move up, you should output ```Up```.
+You should plan ahead and need to achieve it in minimum number of steps.
+You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it.
+
+Please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right.
+"""
 
-    def do_nothing(self):
-        pass
+USER_PROMPT = """Current Observation (0):
+_ G _
+_ _ _
+P O O
+You have not achieved the goal, P has not reached G yet. Please give the next action.
+The maximum number of steps remaining is 10.
+"""
 
 
 class TestLogprobs(RayUnittestBaseAysnc):
@@ -458,31 +501,76 @@ def setUp(self):
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
 
-    async def test_logprobs(self):
-        # use init process group to apply patches
-        sync = (
-            ray.remote(DummySynchronizer)
-            .options(name="synchronizer", namespace=self.config.ray_namespace)
-            .remote()
-        )
-        await sync.__ray_ready__.remote()
+    async def test_logprobs_api(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
         await self.model_wrapper.prepare()
-        master_address, master_port = await self.engines[0].get_available_address.remote()
-        await self.engines[0].init_process_group.remote(
-            master_address,
-            master_port,
-            world_size=1,
-            rank_offset=0,
-            group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
-            explorer_name=self.config.explorer.name,
-            timeout=20,
-        )
         messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What is your name?"},
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": USER_PROMPT},
         ]
-        response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
-        response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
+
+        # Test openai api logprobs with different temperature
+
+        self.model_client = self.model_wrapper.get_openai_async_client()
+        _ = await self.model_client.chat.completions.create(
+            model=self.model_client.model_path,
+            messages=messages,
+            n=1,
+            temperature=1.0,
+            logprobs=True,
+            max_tokens=15,
+        )
+        response_1 = self.model_wrapper.extract_experience_from_history()[0]
+        _ = await self.model_client.chat.completions.create(
+            model=self.model_client.model_path,
+            messages=messages,
+            n=1,
+            temperature=0.8,
+            logprobs=True,
+            max_tokens=15,
+        )
+        response_2 = self.model_wrapper.extract_experience_from_history()[0]
+        self.assertTrue(response_1.logprobs is not None)
+        self.assertTrue(len(response_1.logprobs) > 0)
+        self.assertTrue(response_2.logprobs is not None)
+        self.assertTrue(len(response_2.logprobs) > 0)
+        logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0)
+        logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8)
+        logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0)
+        logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
+        self.assertEqual(logprobs_1.shape, logprobs_2.shape)
+        self.assertEqual(logprobs_3.shape, logprobs_4.shape)
+        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
+        logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
+        logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
+        logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
+        logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
+        self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
+        logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
+        logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
+        logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
+        logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
+        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
+        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
+        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
+        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+
+        # test vllm engine logprobs with different temperature
+        response_1 = self.model_wrapper.chat(
+            messages, n=1, temperature=1.0, logprobs=True, max_tokens=15
+        )[0]
+        response_2 = self.model_wrapper.chat(
+            messages, n=1, temperature=0.8, logprobs=True, max_tokens=15
+        )[0]
         self.assertTrue(response_1.logprobs is not None)
         self.assertTrue(len(response_1.logprobs) > 0)
         self.assertTrue(response_2.logprobs is not None)
@@ -517,6 +605,56 @@ async def test_logprobs(self):
         self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
         self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
 
+        # test openai api and vllm engine logprobs consistency
+        await self.model_wrapper.clean_workflow_state()
+        _ = await self.model_client.chat.completions.create(
+            model=self.model_client.model_path,
+            messages=messages,
+            n=1,
+            temperature=1.0,
+            logprobs=0,
+            max_tokens=1,
+        )
+        response_openai_1 = self.model_wrapper.extract_experience_from_history()[0]
+        _ = await self.model_client.chat.completions.create(
+            model=self.model_client.model_path,
+            messages=messages,
+            n=1,
+            temperature=0.8,
+            logprobs=0,
+            max_tokens=1,
+        )
+        response_openai_2 = self.model_wrapper.extract_experience_from_history()[0]
+        response_vllm_1 = self.model_wrapper.chat(
+            messages,
+            n=1,
+            temperature=1.0,
+            logprobs=0,
+            max_tokens=1,
+        )[0]
+        response_vllm_2 = self.model_wrapper.chat(
+            messages,
+            n=1,
+            temperature=0.8,
+            logprobs=0,
+            max_tokens=1,
+        )[0]
+        self.assertEqual(len(response_openai_1.tokens), len(response_vllm_1.tokens))
+        self.assertTrue(
+            torch.allclose(
+                response_openai_1.logprobs,
+                response_vllm_1.logprobs,
+                rtol=0.1,
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                response_openai_2.logprobs,
+                response_vllm_2.logprobs,
+                rtol=0.1,
+            )
+        )
+
 
 class TestAsyncAPIServer(RayUnittestBaseAysnc):
     def setUp(self):
@@ -537,6 +675,7 @@ def setUp(self):
         )
 
     async def test_api_async(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
         await self.model_wrapper.prepare()
         await self.model_wrapper_no_history.prepare()
         openai_client = self.model_wrapper.get_openai_async_client()
@@ -758,6 +897,7 @@ async def test_api_tool_calls(self):
         import json
         import time
 
+        await prepare_engines(self.engines, self.auxiliary_engines)
         await self.model_wrapper.prepare()
         await self.model_wrapper_no_history.prepare()
         tokenizer = AutoTokenizer.from_pretrained(get_api_model_path())
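
The heart of this change is the consistency check above: the same prompt goes once through the OpenAI-compatible endpoint and once through the engine's native chat path, and the experiences recorded for both must carry matching logprobs. Below is a minimal sketch of that pattern using only ModelWrapper calls that appear in the diff (get_openai_async_client, extract_experience_from_history, chat); the check_consistency helper itself is illustrative and not part of the commit.

import torch


async def check_consistency(model_wrapper, messages):
    # Illustrative helper (not in the commit): one completion through the
    # OpenAI-compatible server, one through the native chat path, then
    # compare the recorded logprobs.
    client = model_wrapper.get_openai_async_client()
    _ = await client.chat.completions.create(
        model=client.model_path,
        messages=messages,
        n=1,
        temperature=1.0,
        logprobs=0,
        max_tokens=1,
    )
    exp_openai = model_wrapper.extract_experience_from_history()[0]
    exp_vllm = model_wrapper.chat(
        messages, n=1, temperature=1.0, logprobs=0, max_tokens=1
    )[0]
    assert len(exp_openai.tokens) == len(exp_vllm.tokens)
    assert torch.allclose(exp_openai.logprobs, exp_vllm.logprobs, rtol=0.1)

The sampling settings (logprobs=0, max_tokens=1, rtol=0.1) mirror those used in the final consistency block of test_logprobs_api.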

tests/explorer/scheduler_test.py

Lines changed: 5 additions & 2 deletions
@@ -175,6 +175,9 @@ class DummyModel(InferenceModel):
     def sync_model(self, model_version, update_weight_args_list):
         return True
 
+    async def prepare(self):
+        return
+
     def get_model_version(self):
         return 0
 
@@ -190,7 +193,7 @@ def init_process_group(
     ) -> None:
         pass
 
-    async def get_api_server_url(self) -> Optional[str]:
+    def get_api_server_url(self) -> Optional[str]:
         return None
 
 
@@ -214,7 +217,7 @@ def init_process_group(
     ) -> None:
         pass
 
-    async def get_api_server_url(self) -> str:
+    def get_api_server_url(self) -> str:
         return "http://localhost:12345"
 
 
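The dummy engines in scheduler_test.py now track the updated inference-model surface: prepare is an awaitable coroutine, while get_api_server_url is a plain synchronous method. A minimal sketch of a compatible test double, restricted to the two methods touched by this diff; the class name is illustrative and the real dummies also implement sync_model, get_model_version, and init_process_group as shown above.

from typing import Optional


class MinimalDummyModel:
    # Illustrative test double (not part of the commit): only the two
    # methods changed by this diff are shown.
    async def prepare(self) -> None:
        # Awaited once before the engine serves requests.
        return

    def get_api_server_url(self) -> Optional[str]:
        # Synchronous now; return None when no OpenAI-compatible server is exposed.
        return None
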
tests/explorer/workflow_test.py

Lines changed: 5 additions & 4 deletions
@@ -475,7 +475,7 @@ def test_workflow_repeatable(self, workflow_cls) -> None:
         (DummyAsyncMultiTurnWorkflow,),
     ],
 )
-class MultiTurnWorkflowTest(unittest.TestCase):
+class MultiTurnWorkflowTest(unittest.IsolatedAsyncioTestCase):
     def setUp(self):
         # configure the model
         self.config = get_template_config()
@@ -490,7 +490,8 @@ def setUp(self):
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
 
-    def test_multi_turn_workflow(self):
+    async def test_multi_turn_workflow(self):
+        await asyncio.gather(*[engine.prepare.remote() for engine in self.engines])
         task = Task(
             workflow=self.workflow_cls,
             repeat_times=3,
@@ -500,7 +501,7 @@ def test_multi_turn_workflow(self):
         workflow = task.to_workflow(self.model_wrapper)
         workflow.set_repeat_times(2, run_id_base=0)
         if workflow.asynchronous:
-            answer = asyncio.run(workflow.run_async())
+            answer = await workflow.run_async()
         else:
             answer = workflow.run()
         self.assertEqual(len(answer), 2)
@@ -727,7 +728,7 @@ async def test_workflow_with_openai(self):
         config.explorer.rollout_model.enable_history = True
         config.check_and_update()
         engines, auxiliary_engines = create_inference_models(config)
-
+        await asyncio.gather(*[engine.prepare.remote() for engine in engines])
         runner = WorkflowRunner(
             config,
             model=engines[0],