
Commit 44e35a1

Add openai client support for tinker backend (#475)

1 parent 434715d

8 files changed: +237 -38 lines
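
The user-visible effect of this commit: `ModelWrapper.get_openai_client()` now also works when the rollout backend is tinker, and the returned client carries a `model_path` attribute that callers use as the model id instead of querying `models.list()`. A minimal usage sketch, assuming `wrapper` is an already-prepared `ModelWrapper` (setup elided; see the tests below):

# Sketch only: `wrapper` is assumed to be a prepared ModelWrapper
# (see tests/common/vllm_test.py below for the full setup).
client = wrapper.get_openai_client()

# `model_path` is attached by this patch, replacing the old
# `client.models.list().data[0].id` lookup.
response = client.chat.completions.create(
    model=client.model_path,
    messages=[{"role": "user", "content": "What is your name?"}],
)
print(response.choices[0].message.content)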

scripts/context_length_test/README.md

Lines changed: 2 additions & 2 deletions
@@ -124,7 +124,7 @@ Below are empirical results from running this script on various Qwen3 models acr
 
 ### A100 80GB
 
-#### Vallina Settings (Baseline)
+#### Vanilla Settings (Baseline)
 
 | #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
 | ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
@@ -177,7 +177,7 @@ Below are empirical results from running this script on various Qwen3 models acr
 ### H20 96GB (Higher VRAM, Slower Bandwidth)
 
 
-#### Vallina Settings
+#### Vanilla Settings
 
 
 | #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |

tests/common/vllm_test.py

Lines changed: 47 additions & 6 deletions
@@ -2,6 +2,7 @@
 import os
 import unittest
 
+import ray
 import torch
 from openai import BadRequestError
 from parameterized import parameterized_class
@@ -13,12 +14,14 @@
     get_model_path,
     get_template_config,
 )
+from trinity.common.config import Config
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
     tokenize_and_mask_messages_default,
     tokenize_and_mask_messages_hf,
 )
+from trinity.manager.synchronizer import Synchronizer
 
 DEBUG = False
 
@@ -669,21 +672,32 @@ async def test_logprobs_api(self):
 
 
 class TestAsyncAPIServer(RayUnittestBaseAsync):
-    def setUp(self):
+    engine_type: str = "vllm"
+    model_path: str = get_model_path()
+
+    async def asyncSetUp(self):
         self.config = get_template_config()
+        self._update_config()
+        await self._setup_engines()
+
+    def _update_config(self):
         self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
+        self.config.model.model_path = self.model_path
         self.config.explorer.rollout_model.engine_type = "vllm"
         self.config.explorer.rollout_model.engine_num = 1
         self.config.explorer.rollout_model.tensor_parallel_size = 1
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
 
         self.config.check_and_update()
+
+    async def _setup_engines(self):
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+        self.model_wrapper = ModelWrapper(
+            self.engines[0], engine_type=self.engine_type, enable_history=True
+        )
         self.model_wrapper_no_history = ModelWrapper(
-            self.engines[0], engine_type="vllm", enable_history=False
+            self.engines[0], engine_type=self.engine_type, enable_history=False
         )
 
     async def test_api_async(self):
@@ -695,7 +709,7 @@ async def test_api_async(self):
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "What is your name?"},
         ]
-        model_id = (await openai_client.models.list()).data[0].id
+        model_id = openai_client.model_path
         response = await openai_client.chat.completions.create(
             model=model_id, messages=messages, n=1
         )
@@ -713,7 +727,8 @@ async def test_api_async(self):
         self.assertTrue(response.choices[0].logprobs is not None)
         self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
         # here we check the 3rd token logprob, because the first two tokens (`<think>`, `\n`) usually have zero logprob
-        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
+        if "Instruct" not in self.model_path:
+            self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
         self.assertTrue(hasattr(response, "prompt_token_ids"))
         self.assertTrue(len(response.prompt_token_ids) > 0)
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -765,6 +780,32 @@ async def test_api_async(self):
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
 
 
+@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
+class TestTinkerAsyncAPIServer(TestAsyncAPIServer):
+    engine_type: str = "tinker"
+    model_path: str = "Qwen/Qwen3-4B-Instruct-2507"
+    # Llama models in Tinker do not support chat templates.
+
+    def _update_config(self):
+        self.config.model.tinker.enable = True
+        self.config.algorithm.algorithm_type = "grpo"
+        super()._update_config()
+
+    async def _setup_engines(self):
+        @ray.remote
+        class FakeTrainer:
+            def __init__(self, config: Config):
+                self.config = config
+                self.synchronizer = Synchronizer.get_actor(config)
+
+        fake_trainer = FakeTrainer.remote(self.config)
+        await fake_trainer.__ray_ready__.remote()
+        await super()._setup_engines()
+
+    async def test_api_async(self):
+        await super().test_api_async()
+
+
 class TestTokenizer(unittest.TestCase):
     def test_action_mask(self):
         messages = [
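
The setup refactor above splits initialization into `_update_config` / `_setup_engines` hooks precisely so backend-specific suites can subclass `TestAsyncAPIServer`, as `TestTinkerAsyncAPIServer` does. A hypothetical further backend would follow the same shape (`my_engine` and the model name below are placeholders, not engines this commit supports):

class TestMyEngineAsyncAPIServer(TestAsyncAPIServer):
    engine_type: str = "my_engine"  # placeholder backend name
    model_path: str = "org/some-model"  # placeholder model id

    def _update_config(self):
        # backend-specific config tweaks go before the shared ones
        super()._update_config()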

tests/trainer/trainer_test.py

Lines changed: 1 addition & 1 deletion
@@ -1402,7 +1402,7 @@ def tearDown(self):
 
 
 class TestTinkerTrainer(BaseTrainerCase):
-    @unittest.skip("Require tinker API key")
+    @unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
     def test_trainer(self):
         """Test GSM8K on tinker."""
         # test both mode

trinity/common/config.py

Lines changed: 5 additions & 0 deletions
@@ -1218,6 +1218,11 @@ def _check_tinker(self) -> None:
             self.explorer.rollout_model.engine_type = "tinker"
             logger.warning("Rollout model engine type is set to `tinker`.")
 
+        for aux_model_config in self.explorer.auxiliary_models:
+            if aux_model_config.engine_type != "tinker":
+                aux_model_config.engine_type = "tinker"
+                logger.warning("Auxiliary model engine type is set to `tinker`.")
+
         if self.trainer.trainer_type != "tinker":
             self.trainer.trainer_type = "tinker"
             logger.warning("Trainer type is set to `tinker`.")

trinity/common/models/__init__.py

Lines changed: 11 additions & 9 deletions
@@ -71,16 +71,18 @@ def create_inference_models(
             for i in range(engine_num)
         ]
         auxiliary_engines = [
-            ray.remote(engine_cls)
-            .options(
-                name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
-                namespace=namespace,
-            )
-            .remote(
-                config=config.explorer.auxiliary_models[i],
-            )
+            [
+                ray.remote(engine_cls)
+                .options(
+                    name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
+                    namespace=namespace,
+                )
+                .remote(
+                    config=config.explorer.auxiliary_models[i],
+                )
+                for j in range(model_config.engine_num)
+            ]
             for i, model_config in enumerate(config.explorer.auxiliary_models)
-            for j in range(model_config.engine_num)
         ]
         return rollout_engines, auxiliary_engines
     else:
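
Note the shape change here: `auxiliary_engines` was a flat list built by a double `for` in a single comprehension and is now a list of lists, with one inner list of engine actors per auxiliary model. A standalone illustration of the two shapes, with tuples standing in for the Ray actor handles:

# Stand-in data: two auxiliary models with 2 and 3 engines respectively.
engine_nums = [2, 3]

flat = [(i, j) for i, n in enumerate(engine_nums) for j in range(n)]
# -> [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]

nested = [[(i, j) for j in range(n)] for i, n in enumerate(engine_nums)]
# -> [[(0, 0), (0, 1)], [(1, 0), (1, 1), (1, 2)]]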

trinity/common/models/model.py

Lines changed: 80 additions & 11 deletions
@@ -71,10 +71,18 @@ def get_api_server_url(self) -> Optional[str]:
         """Get the API server URL if available."""
         return None
 
+    def get_api_key(self) -> str:
+        """Get the API key."""
+        return "EMPTY"
+
     def get_model_config(self) -> InferenceModelConfig:
         """Get the model configuration."""
         return self.config
 
+    def get_model_path(self) -> Optional[str]:
+        """Get the model path."""
+        return self.config.model_path
+
 
 def _history_recorder(func):
     """Decorator to record history of the model calls."""
@@ -118,10 +126,11 @@ def __init__(
             engine_type.startswith("vllm") or engine_type == "tinker"
         ), "Only vLLM and tinker models are supported for now."
         self.model = model
+        self.engine_type = engine_type
         self.config: InferenceModelConfig = None  # init during prepare
         self._model_name: str = None
-        self._model_path: str = None
         self.api_address: str = None
+        self._api_key: str = None
         self.openai_client: openai.OpenAI = None
         self.openai_async_client: openai.AsyncOpenAI = None
         self.logger = get_logger(__name__)
@@ -138,7 +147,7 @@ async def prepare(self) -> None:
         """Prepare the model wrapper."""
         self.config = await self.model.get_model_config.remote()
         self._model_name = self.config.name
-        self._model_path = self.config.model_path
+        self._api_key = await self.model.get_api_key.remote()
         self._generate_kwargs = {
             "temperature": self.config.temperature,
             "top_p": self.config.top_p,
@@ -152,6 +161,8 @@ async def prepare(self) -> None:
         if self.api_address is None:
             self.logger.info("API server is not enabled for inference model.")
             return
+        if self.engine_type == "tinker":
+            return
         max_retries = 30
         interval = 2  # seconds
         for i in range(max_retries):
@@ -285,6 +296,11 @@ async def convert_messages_to_experience_async(
             messages, tools=tools, temperature=temperature
         )
 
+    @property
+    def api_key(self) -> str:
+        """Get the API key."""
+        return self._api_key
+
     @property
     def model_version(self) -> int:
         """Get the version of the model."""
@@ -297,8 +313,23 @@ async def model_version_async(self) -> int:
 
     @property
     def model_path(self) -> str:
-        """Get the model path."""
-        return self._model_path
+        """
+        Returns the path to the model files based on the current engine type.
+
+        - For 'vllm' engine: returns the model path from the configuration (`config.model_path`)
+        - For 'tinker' engine: returns the path to the most recent sampler weights
+        """
+        return ray.get(self.model.get_model_path.remote())
+
+    @property
+    async def model_path_async(self) -> str:
+        """
+        Returns the path to the model files based on the current engine type.
+
+        - For 'vllm' engine: returns the model path from the configuration (`config.model_path`)
+        - For 'tinker' engine: returns the path to the most recent sampler weights
+        """
+        return await self.model.get_model_path.remote()
 
     @property
     def model_name(self) -> Optional[str]:
@@ -332,16 +363,36 @@ def get_openai_client(self) -> openai.OpenAI:
             openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
         """
         if self.openai_client is not None:
+            setattr(self.openai_client, "model_path", self.model_path)
             return self.openai_client
         if not self.api_address:
             raise ValueError(
                 "API server is not enabled for this model. OpenAI client is unavailable."
             )
         self.openai_client = openai.OpenAI(
             base_url=f"{self.api_address}/v1",
-            api_key="EMPTY",
+            api_key=self._api_key,
         )
-        if self.enable_history:
+        if self.engine_type == "tinker":
+            # ! TODO: because tinker's OpenAI API interface is in beta,
+            # we need to use tinker's native API instead.
+            def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = ray.get(
+                    self.model.chat.remote(
+                        messages=messages,
+                        with_chat_completion=True,
+                        return_token_ids=self.enable_history,
+                        **kwargs,
+                    )
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_client.chat.completions.create = chat_completions
+        elif self.enable_history:
             # add a decorator to the openai client to record history
 
             ori_create = self.openai_client.chat.completions.create
@@ -359,7 +410,7 @@ def record_chat_completions(*args, **kwargs):
                 return response
 
             self.openai_client.chat.completions.create = record_chat_completions
-        setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id)
+        setattr(self.openai_client, "model_path", self.model_path)
         return self.openai_client
 
     def get_openai_async_client(self) -> openai.AsyncOpenAI:
@@ -369,6 +420,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
             openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
         """
         if self.openai_async_client is not None:
+            setattr(self.openai_async_client, "model_path", self.model_path)
             return self.openai_async_client
         if not self.api_address:
             raise ValueError(
@@ -377,9 +429,27 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
         # first make sure that we have the sync openai client
         self.openai_async_client = openai.AsyncOpenAI(
             base_url=f"{self.api_address}/v1",
-            api_key="EMPTY",
+            api_key=self._api_key,
         )
-        if self.enable_history:
+
+        if self.engine_type == "tinker":
+            # ! TODO: because tinker's OpenAI API interface is in beta,
+            # we need to use tinker's native API instead.
+            async def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = await self.model.chat.remote(
+                    messages=messages,
+                    with_chat_completion=True,
+                    return_token_ids=self.enable_history,
+                    **kwargs,
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_async_client.chat.completions.create = chat_completions
+        elif self.enable_history:
            # add a decorator to the openai client to record history
 
             ori_create = self.openai_async_client.chat.completions.create
@@ -400,8 +470,7 @@ async def record_chat_completions(*args, **kwargs):
 
             self.openai_async_client.chat.completions.create = record_chat_completions
         # get model_path from the sync openai client to avoid async call here
-        openai_client = self.get_openai_client()
-        setattr(self.openai_async_client, "model_path", openai_client.models.list().data[0].id)
+        setattr(self.openai_async_client, "model_path", self.model_path)
         return self.openai_async_client
 
     async def get_current_load(self) -> int:
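
On the tinker path with `enable_history=True`, the replacement `chat.completions.create` takes the `ChatCompletion` from the end of the `model.chat` response (`chat_response.pop()`) and records the remaining entries in the wrapper's `history`. A sketch of the resulting behavior (the layout of the `model.chat` return value is inferred from the `pop()` usage above, so treat it as an assumption):

# `wrapper` is assumed to be a prepared ModelWrapper over a tinker
# engine with enable_history=True; create() below is routed through
# model.chat.remote(...) rather than an HTTP endpoint.
client = wrapper.get_openai_client()
before = len(wrapper.history)

response = client.chat.completions.create(
    model=client.model_path,
    messages=[{"role": "user", "content": "Hello"}],
)

print(response.choices[0].message.content)
# History should have grown if experiences were returned alongside
# the completion.
assert len(wrapper.history) >= before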
