198 changes: 169 additions & 29 deletions tests/common/vllm_test.py
@@ -1,7 +1,7 @@
import asyncio
import os
import unittest

import ray
import torch
from openai import BadRequestError
from parameterized import parameterized_class
@@ -13,7 +13,6 @@
get_model_path,
get_template_config,
)
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
@@ -29,6 +28,16 @@ def print_debug(*args):
print(*args)


async def prepare_engines(engines, auxiliary_engines):
    # Kick off prepare() on every rollout engine and on every engine in each
    # auxiliary engine group, then wait for all of them concurrently.
    prepare_model_refs = []
    for engine in engines:
        prepare_model_refs.append(engine.prepare.remote())
    for engine_group in auxiliary_engines:
        for engine in engine_group:
            prepare_model_refs.append(engine.prepare.remote())
    await asyncio.gather(*prepare_model_refs)


# Qwen2.5 chat template with {% generation %} mark
CHAT_TEMPLATE = r"""
{%- if tools %}
@@ -127,6 +136,7 @@ def setUp(self):
async def test_generate(
self,
):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
self.assertEqual(await self.model_wrapper.model_path_async, self.config.model.model_path)
@@ -244,6 +254,7 @@ def setUp(self):
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)

async def test_model_len(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -311,6 +322,7 @@ def setUp(self):
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_model_len(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
messages = [
{"role": "user", "content": "How are you?"},
@@ -362,6 +374,7 @@ def setUp(self):
)

async def test_api(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
openai_client = self.model_wrapper.get_openai_client()
@@ -435,12 +448,42 @@ async def test_api(self):
self.assertEqual(len(self.model_wrapper_no_history.history), 0)


class DummySynchronizer:
def __init__(self):
pass
SYSTEM_PROMPT = """
You are Qwen, created by Alibaba Cloud. You are a helpful assistant. You are walking on a frozen lake.

FrozenLake Quick Guide
Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.

Symbols:
_ Frozen | O Hole | G Goal | P Player

Rules:
1. Avoid falling into holes (O).
2. Frozen tiles are slippery; you may move perpendicular to your intended direction.

Valid Action (separated by | ):
Up | Down | Left | Right

Rewards:
Fall into hole: 0
Reach goal: +1.0

You will be provided the current observation; please decide on the next Action.
You should show your thought process and then put the final action in ``` ```.
You should only output the NEXT ACTION at each iteration in the ``` ```. For example, if you want to move up, you should output ```Up```.
You should plan ahead and aim to reach the goal in the minimum number of steps.
You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it.

Please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right.
"""

def do_nothing(self):
pass
USER_PROMPT = """Current Observation (0):
_ G _
_ _ _
P O O
You have not achieved the goal, P has not reached G yet. Please give the next action.
The maximum number of steps remaining is 10.
"""


class TestLogprobs(RayUnittestBaseAysnc):
@@ -458,31 +501,76 @@ def setUp(self):
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_logprobs(self):
# use init process group to apply patches
sync = (
ray.remote(DummySynchronizer)
.options(name="synchronizer", namespace=self.config.ray_namespace)
.remote()
)
await sync.__ray_ready__.remote()
async def test_logprobs_api(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
master_address, master_port = await self.engines[0].get_available_address.remote()
await self.engines[0].init_process_group.remote(
master_address,
master_port,
world_size=1,
rank_offset=0,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
explorer_name=self.config.explorer.name,
timeout=20,
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT},
]
response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]

# Test OpenAI API logprobs with different temperatures

self.model_client = self.model_wrapper.get_openai_async_client()
_ = await self.model_client.chat.completions.create(
model=self.model_client.model_path,
messages=messages,
n=1,
temperature=1.0,
logprobs=True,
max_tokens=15,
)
response_1 = self.model_wrapper.extract_experience_from_history()[0]
_ = await self.model_client.chat.completions.create(
model=self.model_client.model_path,
messages=messages,
n=1,
temperature=0.8,
logprobs=True,
max_tokens=15,
)
response_2 = self.model_wrapper.extract_experience_from_history()[0]
self.assertTrue(response_1.logprobs is not None)
self.assertTrue(len(response_1.logprobs) > 0)
self.assertTrue(response_2.logprobs is not None)
self.assertTrue(len(response_2.logprobs) > 0)
logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0)
logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8)
logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0)
logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
self.assertEqual(logprobs_1.shape, logprobs_2.shape)
self.assertEqual(logprobs_3.shape, logprobs_4.shape)
self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))

# test vLLM engine logprobs with different temperatures
response_1 = self.model_wrapper.chat(
messages, n=1, temperature=1.0, logprobs=True, max_tokens=15
)[0]
response_2 = self.model_wrapper.chat(
messages, n=1, temperature=0.8, logprobs=True, max_tokens=15
)[0]
self.assertTrue(response_1.logprobs is not None)
self.assertTrue(len(response_1.logprobs) > 0)
self.assertTrue(response_2.logprobs is not None)
@@ -517,6 +605,56 @@ async def test_logprobs(self):
self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))

# test OpenAI API and vLLM engine logprobs consistency
await self.model_wrapper.clean_workflow_state()
_ = await self.model_client.chat.completions.create(
model=self.model_client.model_path,
messages=messages,
n=1,
temperature=1.0,
logprobs=0,
max_tokens=1,
)
response_openai_1 = self.model_wrapper.extract_experience_from_history()[0]
_ = await self.model_client.chat.completions.create(
model=self.model_client.model_path,
messages=messages,
n=1,
temperature=0.8,
logprobs=0,
max_tokens=1,
)
response_openai_2 = self.model_wrapper.extract_experience_from_history()[0]
response_vllm_1 = self.model_wrapper.chat(
messages,
n=1,
temperature=1.0,
logprobs=0,
max_tokens=1,
)[0]
response_vllm_2 = self.model_wrapper.chat(
messages,
n=1,
temperature=0.8,
logprobs=0,
max_tokens=1,
)[0]
self.assertEqual(len(response_openai_1.tokens), len(response_vllm_1.tokens))
self.assertTrue(
torch.allclose(
response_openai_1.logprobs,
response_vllm_1.logprobs,
rtol=0.1,
)
)
self.assertTrue(
torch.allclose(
response_openai_2.logprobs,
response_vllm_2.logprobs,
rtol=0.1,
)
)


class TestAsyncAPIServer(RayUnittestBaseAysnc):
def setUp(self):
@@ -537,6 +675,7 @@ def setUp(self):
)

async def test_api_async(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
openai_client = self.model_wrapper.get_openai_async_client()
@@ -758,6 +897,7 @@ async def test_api_tool_calls(self):
import json
import time

await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
tokenizer = AutoTokenizer.from_pretrained(get_api_model_path())
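For context only (not part of the PR diff): a minimal sketch of why the logprobs assertions above expect the same shapes but different values at temperature 1.0 and 0.8. It assumes standard temperature scaling, i.e. per-token log-probabilities are log_softmax(logits / T); the function and variable names below are illustrative and are not part of the Trinity or vLLM APIs.

import torch
import torch.nn.functional as F


def token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor, temperature: float) -> torch.Tensor:
    # Log-probability of each chosen token under temperature-scaled logits.
    # logits: [seq_len, vocab_size], token_ids: [seq_len].
    scaled = F.log_softmax(logits / temperature, dim=-1)
    return scaled.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 10)          # 4 positions, vocabulary of 10
    tokens = torch.randint(0, 10, (4,))  # tokens that were actually sampled
    lp_t10 = token_logprobs(logits, tokens, temperature=1.0)
    lp_t08 = token_logprobs(logits, tokens, temperature=0.8)
    # Same shape, generally different values -- mirroring the
    # assertEqual(shape) / assertFalse(allclose) pattern in the test above.
    print(lp_t10.shape == lp_t08.shape, lp_t10, lp_t08)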
7 changes: 5 additions & 2 deletions tests/explorer/scheduler_test.py
@@ -175,6 +175,9 @@ class DummyModel(InferenceModel):
def sync_model(self, model_version, update_weight_args_list):
return True

async def prepare(self):
return

def get_model_version(self):
return 0

@@ -190,7 +193,7 @@ def init_process_group(
) -> None:
pass

async def get_api_server_url(self) -> Optional[str]:
def get_api_server_url(self) -> Optional[str]:
return None


@@ -214,7 +217,7 @@ def init_process_group(
) -> None:
pass

async def get_api_server_url(self) -> str:
def get_api_server_url(self) -> str:
return "http://localhost:12345"


9 changes: 5 additions & 4 deletions tests/explorer/workflow_test.py
@@ -475,7 +475,7 @@ def test_workflow_repeatable(self, workflow_cls) -> None:
(DummyAsyncMultiTurnWorkflow,),
],
)
class MultiTurnWorkflowTest(unittest.TestCase):
class MultiTurnWorkflowTest(unittest.IsolatedAsyncioTestCase):
def setUp(self):
# configure the model
self.config = get_template_config()
@@ -490,7 +490,8 @@ def setUp(self):
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

def test_multi_turn_workflow(self):
async def test_multi_turn_workflow(self):
await asyncio.gather(*[engine.prepare.remote() for engine in self.engines])
task = Task(
workflow=self.workflow_cls,
repeat_times=3,
@@ -500,7 +501,7 @@ def test_multi_turn_workflow(self):
workflow = task.to_workflow(self.model_wrapper)
workflow.set_repeat_times(2, run_id_base=0)
if workflow.asynchronous:
answer = asyncio.run(workflow.run_async())
answer = await workflow.run_async()
else:
answer = workflow.run()
self.assertEqual(len(answer), 2)
@@ -727,7 +728,7 @@ async def test_workflow_with_openai(self):
config.explorer.rollout_model.enable_history = True
config.check_and_update()
engines, auxiliary_engines = create_inference_models(config)

await asyncio.gather(*[engine.prepare.remote() for engine in engines])
runner = WorkflowRunner(
config,
model=engines[0],
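For context only (not part of the PR diff): the workflow_test.py change above switches MultiTurnWorkflowTest from unittest.TestCase to unittest.IsolatedAsyncioTestCase so the test body can await coroutines (such as engine.prepare.remote() and workflow.run_async()) directly instead of wrapping them in asyncio.run(). A self-contained sketch of that pattern, with placeholder coroutines standing in for the real engine calls:

import asyncio
import unittest


class AsyncStyleTest(unittest.IsolatedAsyncioTestCase):
    async def test_awaits_coroutines_directly(self):
        async def fake_prepare():
            # Stands in for an engine.prepare.remote() call.
            await asyncio.sleep(0)
            return "ready"

        # Gather several prepare calls concurrently, as the updated test does.
        results = await asyncio.gather(*[fake_prepare() for _ in range(2)])
        self.assertEqual(results, ["ready", "ready"])


if __name__ == "__main__":
    unittest.main()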