4 changes: 2 additions & 2 deletions scripts/context_length_test/README.md
@@ -124,7 +124,7 @@ Below are empirical results from running this script on various Qwen3 models acr

### A100 80GB

-#### Vallina Settings (Baseline)
+#### Vanilla Settings (Baseline)

| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
@@ -177,7 +177,7 @@ Below are empirical results from running this script on various Qwen3 models acr
### H20 96GB (Higher VRAM, Slower Bandwidth)


-#### Vallina Settings
+#### Vanilla Settings


| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
57 changes: 51 additions & 6 deletions tests/common/vllm_test.py
@@ -669,21 +669,32 @@ async def test_logprobs_api(self):


class TestAsyncAPIServer(RayUnittestBaseAsync):
-    def setUp(self):
+    engine_type: str = "vllm"
+    model_path: str = get_model_path()
+
+    async def asyncSetUp(self):
        self.config = get_template_config()
+        self._update_config()
+        await self._setup_engines()
+
+    def _update_config(self):
self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
+        self.config.model.model_path = self.model_path
self.config.explorer.rollout_model.engine_type = "vllm"
self.config.explorer.rollout_model.engine_num = 1
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True

self.config.check_and_update()

+    async def _setup_engines(self):
self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+        self.model_wrapper = ModelWrapper(
+            self.engines[0], engine_type=self.engine_type, enable_history=True
+        )
        self.model_wrapper_no_history = ModelWrapper(
-            self.engines[0], engine_type="vllm", enable_history=False
+            self.engines[0], engine_type=self.engine_type, enable_history=False
)

async def test_api_async(self):
@@ -695,7 +706,7 @@ async def test_api_async(self):
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
-        model_id = (await openai_client.models.list()).data[0].id
+        model_id = openai_client.model_path
response = await openai_client.chat.completions.create(
model=model_id, messages=messages, n=1
)
@@ -713,7 +724,8 @@ async def test_api_async(self):
self.assertTrue(response.choices[0].logprobs is not None)
self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
        # check the 3rd token's logprob, because the first two tokens (`<think>`, `\n`) usually have zero logprob
-        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
+        if "Instruct" not in self.model_path:
+            self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
self.assertTrue(hasattr(response, "prompt_token_ids"))
self.assertTrue(len(response.prompt_token_ids) > 0)
self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -765,6 +777,39 @@ async def test_api_async(self):
self.assertEqual(len(self.model_wrapper_no_history.history), 0)


+class TestTinkerAsyncAPIServer(TestAsyncAPIServer):
+    engine_type: str = "tinker"
+    model_path: str = "Qwen/Qwen3-4B-Instruct-2507"
+    # Llama models in Tinker do not support chat templates
+
+    def _update_config(self):
+        self.config.model.tinker.enable = True
+        self.config.algorithm.algorithm_type = "grpo"
+        super()._update_config()
+        from pprint import pprint
+
+        pprint(self.config)
+
+    async def _setup_engines(self):
+        import ray
+
+        from trinity.common.config import Config
+        from trinity.manager.synchronizer import Synchronizer
+
+        @ray.remote
+        class FakeTrainer:
+            def __init__(self, config: Config):
+                self.config = config
+                self.synchronizer = Synchronizer.get_actor(config)
+
+        fake_trainer = FakeTrainer.remote(self.config)
+        await fake_trainer.__ray_ready__.remote()
+        await super()._setup_engines()
+
+    async def test_api_async(self):
+        await super().test_api_async()


class TestTokenizer(unittest.TestCase):
def test_action_mask(self):
messages = [
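The new `TestTinkerAsyncAPIServer` reuses the entire vLLM test suite by overriding two class attributes of `TestAsyncAPIServer`. A minimal sketch of this parameterize-by-subclassing pattern (the class and test names below are illustrative, not from this PR):

```python
import unittest


class BaseEngineTest(unittest.IsolatedAsyncioTestCase):
    # Subclasses override these attributes to retarget the whole suite.
    engine_type: str = "vllm"
    model_path: str = "Qwen/Qwen3-0.6B"  # placeholder model path

    async def test_engine_type_is_supported(self):
        self.assertIn(self.engine_type, ("vllm", "tinker"))


class TinkerEngineTest(BaseEngineTest):
    # Every test defined on BaseEngineTest now also runs with these settings.
    engine_type = "tinker"
    model_path = "Qwen/Qwen3-4B-Instruct-2507"
```

Because unittest discovers `TinkerEngineTest` as its own class, each inherited test runs once per engine configuration.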
5 changes: 5 additions & 0 deletions trinity/common/config.py
@@ -1218,6 +1218,11 @@ def _check_tinker(self) -> None:
self.explorer.rollout_model.engine_type = "tinker"
logger.warning("Rollout model engine type is set to `tinker`.")

+        for aux_model_config in self.explorer.auxiliary_models:
+            if aux_model_config.engine_type != "tinker":
+                aux_model_config.engine_type = "tinker"
+                logger.warning("Auxiliary model engine type is set to `tinker`.")

if self.trainer.trainer_type != "tinker":
self.trainer.trainer_type = "tinker"
logger.warning("Trainer type is set to `tinker`.")
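The added loop mirrors the existing rollout-model and trainer coercions: once tinker is enabled, every auxiliary model is forced onto the tinker engine with a warning. A standalone sketch of that coercion behavior (`AuxModelConfig` and `check_tinker` are simplified stand-ins for the real config classes):

```python
import logging
from dataclasses import dataclass

logging.basicConfig()
logger = logging.getLogger(__name__)


@dataclass
class AuxModelConfig:
    engine_type: str = "vllm"


def check_tinker(auxiliary_models: list[AuxModelConfig]) -> None:
    # Force every auxiliary model onto the tinker engine, warning on change.
    for aux in auxiliary_models:
        if aux.engine_type != "tinker":
            aux.engine_type = "tinker"
            logger.warning("Auxiliary model engine type is set to `tinker`.")


models = [AuxModelConfig("vllm"), AuxModelConfig("tinker")]
check_tinker(models)
assert all(m.engine_type == "tinker" for m in models)
```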
20 changes: 11 additions & 9 deletions trinity/common/models/__init__.py
@@ -71,16 +71,18 @@ def create_inference_models(
for i in range(engine_num)
]
auxiliary_engines = [
-            ray.remote(engine_cls)
-            .options(
-                name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
-                namespace=namespace,
-            )
-            .remote(
-                config=config.explorer.auxiliary_models[i],
-            )
+            [
+                ray.remote(engine_cls)
+                .options(
+                    name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
+                    namespace=namespace,
+                )
+                .remote(
+                    config=config.explorer.auxiliary_models[i],
+                )
+                for j in range(model_config.engine_num)
+            ]
            for i, model_config in enumerate(config.explorer.auxiliary_models)
-            for j in range(model_config.engine_num)
        ]
return rollout_engines, auxiliary_engines
else:
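This change alters the shape of `auxiliary_engines`: it was a flat list over all (model, engine) pairs, and is now a list of lists with one sub-list per auxiliary model, so engine `j` of model `i` is addressed as `auxiliary_engines[i][j]`. A plain-Python sketch of the shape change (tuples stand in for Ray actor handles; the engine counts are made up):

```python
aux_model_configs = [{"engine_num": 2}, {"engine_num": 1}]

# Before: one flat list mixing every auxiliary model's engines.
flat = [
    (i, j)
    for i, cfg in enumerate(aux_model_configs)
    for j in range(cfg["engine_num"])
]
assert flat == [(0, 0), (0, 1), (1, 0)]

# After: one sub-list per auxiliary model.
nested = [
    [(i, j) for j in range(cfg["engine_num"])]
    for i, cfg in enumerate(aux_model_configs)
]
assert nested == [[(0, 0), (0, 1)], [(1, 0)]]
```

Any caller that iterated the old flat list presumably needs to index per model or flatten the nested result.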
84 changes: 74 additions & 10 deletions trinity/common/models/model.py
@@ -71,10 +71,18 @@ def get_api_server_url(self) -> Optional[str]:
"""Get the API server URL if available."""
return None

+    def get_api_key(self) -> str:
+        """Get the API key."""
+        return "EMPTY"

def get_model_config(self) -> InferenceModelConfig:
"""Get the model configuration."""
return self.config

+    def get_model_path(self) -> Optional[str]:
+        """Get the model path."""
+        return self.config.model_path


def _history_recorder(func):
"""Decorator to record history of the model calls."""
@@ -118,10 +126,11 @@ def __init__(
engine_type.startswith("vllm") or engine_type == "tinker"
), "Only vLLM and tinker model is supported for now."
self.model = model
+        self.engine_type = engine_type
self.config: InferenceModelConfig = None # init during prepare
self._model_name: str = None
-        self._model_path: str = None
        self.api_address: str = None
+        self._api_key: str = None
self.openai_client: openai.OpenAI = None
self.openai_async_client: openai.AsyncOpenAI = None
self.logger = get_logger(__name__)
@@ -138,7 +147,7 @@ async def prepare(self) -> None:
"""Prepare the model wrapper."""
self.config = await self.model.get_model_config.remote()
self._model_name = self.config.name
-        self._model_path = self.config.model_path
+        self._api_key = await self.model.get_api_key.remote()
self._generate_kwargs = {
"temperature": self.config.temperature,
"top_p": self.config.top_p,
@@ -152,6 +161,8 @@ async def prepare(self) -> None:
if self.api_address is None:
self.logger.info("API server is not enabled for inference model.")
return
if self.engine_type == "tinker":
return
max_retries = 30
interval = 2 # seconds
for i in range(max_retries):
@@ -285,6 +296,11 @@ async def convert_messages_to_experience_async(
messages, tools=tools, temperature=temperature
)

+    @property
+    def api_key(self) -> str:
+        """Get the API key."""
+        return self._api_key

@property
def model_version(self) -> int:
"""Get the version of the model."""
@@ -298,7 +314,12 @@ async def model_version_async(self) -> int:
@property
def model_path(self) -> str:
"""Get the model path."""
-        return self._model_path
+        return ray.get(self.model.get_model_path.remote())
+
+    @property
+    async def model_path_async(self) -> str:
+        """Get the model path."""
+        return await self.model.get_model_path.remote()

@property
def model_name(self) -> Optional[str]:
@@ -332,16 +353,38 @@ def get_openai_client(self) -> openai.OpenAI:
            openai.OpenAI: The openai client, with a `model_path` attribute that refers to the model path.
"""
if self.openai_client is not None:
setattr(self.openai_client, "model_path", self.model_path)
return self.openai_client
if not self.api_address:
raise ValueError(
"API server is not enabled for this model. OpenAI client is unavailable."
)
self.openai_client = openai.OpenAI(
base_url=f"{self.api_address}/v1",
api_key="EMPTY",
api_key=self._api_key,
)
-        if self.enable_history:
+        if self.engine_type == "tinker":
+            # ! TODO: tinker's OpenAI API interface is still in beta,
+            # so we use the engine's own chat API here instead.
+
+            def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = ray.get(
+                    self.model.chat.remote(
+                        messages=messages,
+                        with_chat_completion=True,
+                        return_token_ids=self.enable_history,
+                        **kwargs,
+                    )
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_client.chat.completions.create = chat_completions
+        elif self.enable_history:
# add a decorator to the openai client to record history

ori_create = self.openai_client.chat.completions.create
@@ -359,7 +402,7 @@ def record_chat_completions(*args, **kwargs):
return response

self.openai_client.chat.completions.create = record_chat_completions
setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id)
setattr(self.openai_client, "model_path", self.model_path)
return self.openai_client

def get_openai_async_client(self) -> openai.AsyncOpenAI:
@@ -369,6 +412,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
            openai.AsyncOpenAI: The async openai client, with a `model_path` attribute that refers to the model path.
"""
if self.openai_async_client is not None:
+            setattr(self.openai_async_client, "model_path", self.model_path)
return self.openai_async_client
if not self.api_address:
raise ValueError(
@@ -377,9 +421,29 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
        # create the async openai client
self.openai_async_client = openai.AsyncOpenAI(
base_url=f"{self.api_address}/v1",
api_key="EMPTY",
api_key=self._api_key,
)
-        if self.enable_history:
+
+        if self.engine_type == "tinker":
+            # ! TODO: tinker's OpenAI API interface is still in beta,
+            # so we use the engine's own chat API here instead.
+
+            async def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = await self.model.chat.remote(
+                    messages=messages,
+                    with_chat_completion=True,
+                    return_token_ids=self.enable_history,
+                    **kwargs,
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_async_client.chat.completions.create = chat_completions
+        elif self.enable_history:
# add a decorator to the openai client to record history

ori_create = self.openai_async_client.chat.completions.create
@@ -400,8 +464,8 @@ async def record_chat_completions(*args, **kwargs):

self.openai_async_client.chat.completions.create = record_chat_completions
-        # get model_path from the sync openai client to avoid async call here
-        openai_client = self.get_openai_client()
-        setattr(self.openai_async_client, "model_path", openai_client.models.list().data[0].id)
+        setattr(self.openai_async_client, "model_path", self.model_path)
return self.openai_async_client

async def get_current_load(self) -> int:
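With these changes, the model path and API key come from the engine actor (`get_model_path` / `get_api_key`) instead of a `models.list()` round-trip, and both client getters refresh the `model_path` attribute on every call. A hedged usage sketch (assuming `wrapper` is an already-prepared `ModelWrapper`; the message content is made up):

```python
# Sync client: `model_path` is attached by get_openai_client().
client = wrapper.get_openai_client()
response = client.chat.completions.create(
    model=client.model_path,
    messages=[{"role": "user", "content": "What is your name?"}],
)
print(response.choices[0].message.content)


# Async client: same attribute; must run inside an event loop.
async def ask() -> str:
    async_client = wrapper.get_openai_async_client()
    response = await async_client.chat.completions.create(
        model=async_client.model_path,
        messages=[{"role": "user", "content": "What is your name?"}],
    )
    return response.choices[0].message.content
```

For tinker engines the patched `chat.completions.create` bypasses the OpenAI HTTP path entirely and calls the engine's `chat` method over Ray, so the same client-side code works for both engine types.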