147 changes: 131 additions & 16 deletions tests/common/vllm_test.py
@@ -1,7 +1,7 @@
import asyncio
import os
import unittest

import ray
import torch
from openai import BadRequestError
from parameterized import parameterized_class
@@ -29,6 +29,16 @@ def print_debug(*args):
print(*args)


async def prepare_engines(engines, auxiliary_engines):
prepare_model_refs = []
for engine in engines:
prepare_model_refs.append(engine.prepare_model.remote())
for aux_engines in auxiliary_engines.values():
for engine in aux_engines:
prepare_model_refs.append(engine.prepare_model.remote())
await asyncio.gather(*prepare_model_refs)


# Qwen2.5 chat template with {% generation %} mark
CHAT_TEMPLATE = r"""
{%- if tools %}
@@ -127,6 +137,7 @@ def setUp(self):
async def test_generate(
self,
):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
self.assertEqual(await self.model_wrapper.model_path_async, self.config.model.model_path)
@@ -244,6 +255,7 @@ def setUp(self):
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)

async def test_model_len(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -311,6 +323,7 @@ def setUp(self):
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_model_len(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
messages = [
{"role": "user", "content": "How are you?"},
@@ -362,6 +375,7 @@ def setUp(self):
)

async def test_api(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
openai_client = self.model_wrapper.get_openai_client()
@@ -435,12 +449,42 @@ async def test_api(self):
self.assertEqual(len(self.model_wrapper_no_history.history), 0)


class DummySynchronizer:
def __init__(self):
pass

def do_nothing(self):
pass
SYSTEM_PROMPT = """
You are Qwen, created by Alibaba Cloud. You are a helpful assistant. You are walking on a frozen lake.

FrozenLake Quick Guide
Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.

Symbols:
_ Frozen | O Hole | G Goal | P Player

Rules:
1. Avoid falling into holes (O).
2. Frozen tiles are slippery; you may move perpendicular to your intended direction.

Valid Action (separated by | ):
Up | Down | Left | Right

Rewards:
Fall into hole: 0
Reach goal: +1.0

You will be provided the current observation; please decide on the next Action.
You should show your thought process and then put the final action in ``` ```.
You should only output the NEXT ACTION at each iteration in the ``` ```. For example, if you want to move up, you should output ```Up```.
You should plan ahead and aim to achieve the goal in the minimum number of steps.
You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it.

Please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right.
"""

USER_PROMPT = """Current Observation (0):
_ G _
_ _ _
P O O
You have not achieved the goal; P has not reached G yet. Please give the next action.
The maximum number of steps remaining is 10.
"""


class TestLogprobs(RayUnittestBaseAysnc):
@@ -458,14 +502,79 @@ def setUp(self):
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_logprobs_api(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
master_address, master_port = await self.engines[0].get_available_address.remote()
await self.engines[0].init_process_group.remote(
master_address,
master_port,
world_size=1,
rank_offset=0,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
explorer_name=self.config.explorer.name,
timeout=20,
)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT},
]
self.model_client = self.model_wrapper.get_openai_async_client()
_ = await self.model_client.chat.completions.create(
model=self.model_client.model_path,
messages=messages,
n=1,
temperature=1.0,
logprobs=True,
max_tokens=15,
)
response_1 = self.model_wrapper.extract_experience_from_history()[0]
_ = await self.model_client.chat.completions.create(
model=self.model_client.model_path,
messages=messages,
n=1,
temperature=0.8,
logprobs=True,
max_tokens=15,
)
response_2 = self.model_wrapper.extract_experience_from_history()[0]
self.assertTrue(response_1.logprobs is not None)
self.assertTrue(len(response_1.logprobs) > 0)
self.assertTrue(response_2.logprobs is not None)
self.assertTrue(len(response_2.logprobs) > 0)
logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0)
logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8)
logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0)
logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
self.assertEqual(logprobs_1.shape, logprobs_2.shape)
self.assertEqual(logprobs_3.shape, logprobs_4.shape)
self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
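A note on the slicing above: it is consistent with a causal LM producing one log-probability per predicted token, i.e. N - 1 values for a sequence of N tokens, so the first prompt_length - 1 entries cover the prompt and the remainder covers the response. A minimal, self-contained illustration with made-up numbers (not part of the test):

import torch

prompt_length = 4
tokens = torch.tensor([101, 7, 8, 9, 42, 43, 44])  # 4 prompt tokens + 3 response tokens
logprobs = torch.randn(len(tokens) - 1)  # one value per predicted token (all but the first)

prompt_logprobs = logprobs[: prompt_length - 1]  # predictions for prompt tokens 2..4
response_logprobs = logprobs[prompt_length - 1 :]  # predictions for the response tokens
assert response_logprobs.shape[0] == len(tokens) - prompt_length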

async def test_logprobs(self):
# use init process group to apply patches
sync = (
ray.remote(DummySynchronizer)
.options(name="synchronizer", namespace=self.config.ray_namespace)
.remote()
)
await sync.__ray_ready__.remote()
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
master_address, master_port = await self.engines[0].get_available_address.remote()
await self.engines[0].init_process_group.remote(
@@ -478,11 +587,15 @@ async def test_logprobs(self):
timeout=20,
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT},
]
response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
response_1 = self.model_wrapper.chat(
messages, n=1, temperature=1.0, logprobs=True, max_tokens=15
)[0]
response_2 = self.model_wrapper.chat(
messages, n=1, temperature=0.8, logprobs=True, max_tokens=15
)[0]
self.assertTrue(response_1.logprobs is not None)
self.assertTrue(len(response_1.logprobs) > 0)
self.assertTrue(response_2.logprobs is not None)
@@ -537,6 +650,7 @@ def setUp(self):
)

async def test_api_async(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
openai_client = self.model_wrapper.get_openai_async_client()
@@ -758,6 +872,7 @@ async def test_api_tool_calls(self):
import json
import time

await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
tokenizer = AutoTokenizer.from_pretrained(get_api_model_path())
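Taken together, the test changes in this file follow one pattern: engines and the model wrapper must now be prepared explicitly before any request is issued. A hedged sketch of that sequence, assuming the import paths below (only prepare_engines is defined in this diff; the other locations are guesses):

import asyncio

from tests.common.vllm_test import prepare_engines  # helper added in this PR
from trinity.common.models import create_inference_models  # assumed location
from trinity.common.models.model import ModelWrapper  # assumed location


async def start_for_tests(config):
    # 1. Create the rollout and auxiliary engines from the config.
    engines, auxiliary_engines = create_inference_models(config)
    # 2. Prepare every engine actor (apply patches, start its API server).
    await prepare_engines(engines, auxiliary_engines)
    # 3. Prepare the wrapper itself before issuing requests.
    model_wrapper = ModelWrapper(engines[0], engine_type="vllm", enable_history=True)
    await model_wrapper.prepare()
    # 4. Only now is the OpenAI-compatible endpoint guaranteed to be available.
    return model_wrapper.get_openai_async_client()


# usage (sketch): client = asyncio.run(start_for_tests(config))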
7 changes: 5 additions & 2 deletions tests/explorer/scheduler_test.py
@@ -175,6 +175,9 @@ class DummyModel(InferenceModel):
def sync_model(self, model_version, update_weight_args_list):
return True

async def prepare(self):
return

def get_model_version(self):
return 0

@@ -190,7 +193,7 @@ def init_process_group(
) -> None:
pass

async def get_api_server_url(self) -> Optional[str]:
def get_api_server_url(self) -> Optional[str]:
return None


@@ -214,7 +217,7 @@ def init_process_group(
) -> None:
pass

async def get_api_server_url(self) -> str:
def get_api_server_url(self) -> str:
return "http://localhost:12345"


6 changes: 4 additions & 2 deletions trinity/common/models/__init__.py
@@ -152,11 +152,13 @@ def create_debug_inference_model(config: Config) -> None:
model.engine_num = 1
rollout_models, auxiliary_models = create_inference_models(config)
# make sure models are started
prepare_refs = []
for m in rollout_models:
ray.get(m.run_api_server.remote())
prepare_refs.append(m.prepare.remote())
for models in auxiliary_models:
for m in models:
ray.get(m.run_api_server.remote())
prepare_refs.append(m.prepare.remote())
ray.get(prepare_refs)
logger.info(
"----------------------------------------------------\n"
"Inference models started successfully for debugging.\n"
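The change above collects object refs and resolves them with a single ray.get instead of blocking on each model in turn, so all prepare() calls run concurrently. A self-contained toy example of the difference (not project code; the actor and timings are illustrative):

import time

import ray


@ray.remote
class SlowActor:
    def prepare(self) -> str:
        time.sleep(1)  # stand-in for patch application / API-server startup
        return "ready"


if __name__ == "__main__":
    ray.init()
    actors = [SlowActor.remote() for _ in range(4)]

    start = time.time()
    for a in actors:
        ray.get(a.prepare.remote())  # serial: each call waits for the previous one (~4s)
    print(f"serial:  {time.time() - start:.1f}s")

    start = time.time()
    ray.get([a.prepare.remote() for a in actors])  # batched: all calls in flight at once (~1s)
    print(f"batched: {time.time() - start:.1f}s")
    ray.shutdown()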
6 changes: 5 additions & 1 deletion trinity/common/models/model.py
@@ -44,6 +44,10 @@ async def convert_messages_to_experience(
"""Convert a list of messages into an experience in async."""
raise NotImplementedError

async def prepare(self) -> None:
"""Prepare the model before inference."""
pass

@abstractmethod
def get_model_version(self) -> int:
"""Get the checkpoint version."""
@@ -56,7 +60,7 @@ def get_available_address(self) -> Tuple[str, int]:
port = s.getsockname()[1]
return address, port

async def get_api_server_url(self) -> Optional[str]:
def get_api_server_url(self) -> Optional[str]:
"""Get the API server URL if available."""
return None

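For illustration, a hedged sketch of what the relaxed contract above allows: prepare() is an optional async hook with a no-op default, so only engines that need startup work override it, and get_api_server_url() is now a plain synchronous accessor. The subclass below is hypothetical and omits the other abstract methods of InferenceModel:

from typing import Optional

from trinity.common.models.model import InferenceModel  # assumed import path


class MyEngine(InferenceModel):  # hypothetical example, not part of the codebase
    def __init__(self) -> None:
        self._url: Optional[str] = None

    async def prepare(self) -> None:
        # One-time startup work, e.g. launching an HTTP server or warming caches.
        self._url = "http://127.0.0.1:8000"

    def get_api_server_url(self) -> Optional[str]:
        # Synchronous: it only reports state established during prepare().
        return self._url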
63 changes: 37 additions & 26 deletions trinity/common/models/vllm_model.py
@@ -114,7 +114,7 @@ def __init__(
**config.lora_kwargs,
)
if get_vllm_version() > parse_version("0.10.0"):
engine_args.enable_log_requests = False
engine_args.enable_log_requests = True
else:
engine_args.disable_log_requests = True
if get_vllm_version() >= parse_version("0.11.0"):
@@ -131,6 +131,7 @@ def __init__(
self.api_server_host = None
self.api_server_port = None
self.api_server = None
self._prepared = False
self.async_lock = asyncio.Lock()

async def _initialize_tokenizer(self):
@@ -144,6 +145,17 @@ def _initialize_processor(self):
)
self.tokenizer = self.processor.tokenizer

async def prepare(
self,
) -> None:
"""Prepare the model for inference."""
async with self.async_lock:
if self._prepared:
return
await self._collective_rpc("apply_patches")
await self.run_api_server()
self._prepared = True

async def chat(
self, messages: List[Dict], lora_request: LoRARequest = None, **kwargs
) -> Sequence[Experience]:
@@ -540,41 +552,40 @@ async def run_api_server(self) -> bool:
Returns:
success (bool): Whether the API server is started successfully.
"""
async with self.async_lock:
if not self.config.enable_openai_api:
return False # Not enabled
if not self.config.enable_openai_api:
return False # Not enabled

if self.api_server_host is not None and self.api_server_port is not None:
return True # already running
if self.api_server_host is not None and self.api_server_port is not None:
return True # already running

from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)
from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)

api_server_host, api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
run_api_server_in_ray_actor(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path,
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
)
api_server_host, api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
run_api_server_in_ray_actor(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path,
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
)
self.api_server_host = api_server_host
self.api_server_port = api_server_port
return True
)
self.api_server_host = api_server_host
self.api_server_port = api_server_port
return True

async def get_api_server_url(self) -> Optional[str]:
def get_api_server_url(self) -> Optional[str]:
"""Get the URL of the OpenAI API server.

Returns:
api_url (str): The URL of the OpenAI API server.
"""
if not await self.run_api_server():
return None
if not self._prepared:
raise RuntimeError("Model is not prepared. Please call `prepare()` first.")
return f"http://{self.api_server_host}:{self.api_server_port}"

async def reset_prefix_cache(self) -> None:
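The net effect in this file: patch application and API-server startup move out of get_api_server_url() into an explicit, one-shot prepare() guarded by the existing asyncio lock, and callers that skip prepare() now fail fast. A minimal, self-contained sketch of that pattern (class and names are illustrative, not the real vLLM model class):

import asyncio


class PreparedOnce:
    def __init__(self) -> None:
        self._prepared = False
        self._lock = asyncio.Lock()

    async def prepare(self) -> None:
        async with self._lock:
            if self._prepared:  # later callers return immediately
                return
            await self._do_prepare()
            self._prepared = True

    async def _do_prepare(self) -> None:
        await asyncio.sleep(0.1)  # stand-in for applying patches and starting the API server

    def get_api_server_url(self) -> str:
        # Mirrors the new behavior: synchronous, and strict about call order.
        if not self._prepared:
            raise RuntimeError("Model is not prepared. Please call `prepare()` first.")
        return "http://127.0.0.1:8000"


async def main() -> None:
    model = PreparedOnce()
    await asyncio.gather(model.prepare(), model.prepare())  # concurrent calls are safe
    print(model.get_api_server_url())


asyncio.run(main())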