Commit d65ed46
Add openai client support for tinker backend
1 parent 15a6d2a commit d65ed46

7 files changed: +237 −39 lines changed

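
Taken together, the changes let explorer-side code keep using a standard OpenAI client even when requests are served by the tinker backend. A minimal, hypothetical usage sketch based on the APIs this commit touches (`model_wrapper` is assumed to be a prepared `ModelWrapper` built with `engine_type="tinker"`, mirroring the test diff below):

```python
async def ask(model_wrapper) -> str:
    client = model_wrapper.get_openai_async_client()
    # After this commit, `model_path` is attached to the client directly,
    # replacing the old `(await client.models.list()).data[0].id` lookup.
    response = await client.chat.completions.create(
        model=client.model_path,
        messages=[{"role": "user", "content": "What is your name?"}],
    )
    return response.choices[0].message.content
```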

scripts/context_length_test/README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -124,7 +124,7 @@ Below are empirical results from running this script on various Qwen3 models acr
 
 ### A100 80GB
 
-#### Vallina Settings (Baseline)
+#### Vanilla Settings (Baseline)
 
 | #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
 | ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
@@ -177,7 +177,7 @@ Below are empirical results from running this script on various Qwen3 models acr
 ### H20 96GB (Higher VRAM, Slower Bandwidth)
 
 
-#### Vallina Settings
+#### Vanilla Settings
 
 
 | #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
```

tests/common/vllm_test.py

Lines changed: 51 additions & 6 deletions
```diff
@@ -669,21 +669,32 @@ async def test_logprobs_api(self):
 
 
 class TestAsyncAPIServer(RayUnittestBaseAsync):
-    def setUp(self):
+    engine_type: str = "vllm"
+    model_path: str = get_model_path()
+
+    async def asyncSetUp(self):
         self.config = get_template_config()
+        self._update_config()
+        await self._setup_engines()
+
+    def _update_config(self):
         self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
+        self.config.model.model_path = self.model_path
         self.config.explorer.rollout_model.engine_type = "vllm"
         self.config.explorer.rollout_model.engine_num = 1
         self.config.explorer.rollout_model.tensor_parallel_size = 1
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
 
         self.config.check_and_update()
+
+    async def _setup_engines(self):
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+        self.model_wrapper = ModelWrapper(
+            self.engines[0], engine_type=self.engine_type, enable_history=True
+        )
         self.model_wrapper_no_history = ModelWrapper(
-            self.engines[0], engine_type="vllm", enable_history=False
+            self.engines[0], engine_type=self.engine_type, enable_history=False
         )
 
     async def test_api_async(self):
@@ -695,7 +706,7 @@ async def test_api_async(self):
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "What is your name?"},
         ]
-        model_id = (await openai_client.models.list()).data[0].id
+        model_id = openai_client.model_path
         response = await openai_client.chat.completions.create(
             model=model_id, messages=messages, n=1
         )
@@ -713,7 +724,8 @@ async def test_api_async(self):
         self.assertTrue(response.choices[0].logprobs is not None)
         self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
         # here we check the 3rd token logprob, because the first two tokens (`<think>`,`\n` usually have zero logprob)
-        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
+        if "Instruct" not in self.model_path:
+            self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
         self.assertTrue(hasattr(response, "prompt_token_ids"))
         self.assertTrue(len(response.prompt_token_ids) > 0)
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -765,6 +777,39 @@ async def test_api_async(self):
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
 
 
+class TestTinkerAsyncAPIServer(TestAsyncAPIServer):
+    engine_type: str = "tinker"
+    model_path: str = "Qwen/Qwen3-4B-Instruct-2507"
+    # llama models in Tinker do not support chat templates
+
+    def _update_config(self):
+        self.config.model.tinker.enable = True
+        self.config.algorithm.algorithm_type = "grpo"
+        super()._update_config()
+        from pprint import pprint
+
+        pprint(self.config)
+
+    async def _setup_engines(self):
+        import ray
+
+        from trinity.common.config import Config
+        from trinity.manager.synchronizer import Synchronizer
+
+        @ray.remote
+        class FakeTrainer:
+            def __init__(self, config: Config):
+                self.config = config
+                self.synchronizer = Synchronizer.get_actor(config)
+
+        fake_trainer = FakeTrainer.remote(self.config)
+        await fake_trainer.__ray_ready__.remote()
+        await super()._setup_engines()
+
+    async def test_api_async(self):
+        await super().test_api_async()
+
+
 class TestTokenizer(unittest.TestCase):
     def test_action_mask(self):
         messages = [
```
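
The tinker variant reuses every test from `TestAsyncAPIServer` by overriding only two class attributes and the setup hooks; the `FakeTrainer` actor appears to exist only so that a `Synchronizer` actor is registered before the engines start. A stripped-down, illustrative sketch of the hook pattern (class names here are hypothetical, not from the repo):

```python
import unittest


class BaseAPITest(unittest.IsolatedAsyncioTestCase):
    engine_type: str = "vllm"  # subclasses override class attributes only

    async def asyncSetUp(self):
        self._update_config()        # hook 1: adjust the config
        await self._setup_engines()  # hook 2: start the engines

    def _update_config(self):
        self.engine = self.engine_type

    async def _setup_engines(self):
        pass  # real classes would create engines here

    async def test_engine(self):
        self.assertIn(self.engine, ("vllm", "tinker"))


class TinkerAPITest(BaseAPITest):
    engine_type = "tinker"  # every inherited test now runs against tinker
```

Subclassing this way means any test added to the base class automatically runs against both backends.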

trinity/common/config.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -1218,6 +1218,11 @@ def _check_tinker(self) -> None:
             self.explorer.rollout_model.engine_type = "tinker"
             logger.warning("Rollout model engine type is set to `tinker`.")
 
+        for aux_model_config in self.explorer.auxiliary_models:
+            if aux_model_config.engine_type != "tinker":
+                aux_model_config.engine_type = "tinker"
+                logger.warning("Auxiliary model engine type is set to `tinker`.")
+
         if self.trainer.trainer_type != "tinker":
             self.trainer.trainer_type = "tinker"
             logger.warning("Trainer type is set to `tinker`.")
```

trinity/common/models/__init__.py

Lines changed: 11 additions & 9 deletions
```diff
@@ -71,16 +71,18 @@ def create_inference_models(
             for i in range(engine_num)
         ]
         auxiliary_engines = [
-            ray.remote(engine_cls)
-            .options(
-                name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
-                namespace=namespace,
-            )
-            .remote(
-                config=config.explorer.auxiliary_models[i],
-            )
+            [
+                ray.remote(engine_cls)
+                .options(
+                    name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
+                    namespace=namespace,
+                )
+                .remote(
+                    config=config.explorer.auxiliary_models[i],
+                )
+                for j in range(model_config.engine_num)
+            ]
             for i, model_config in enumerate(config.explorer.auxiliary_models)
-            for j in range(model_config.engine_num)
         ]
         return rollout_engines, auxiliary_engines
     else:
```
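
The regrouped comprehension changes the result's shape, not its contents: the old flat version mixed all replicas into one list, while the nested version yields one sub-list per auxiliary model, so `auxiliary_engines[i]` now holds all `engine_num` replicas of model `i`. A toy reproduction of the shape difference (plain `(i, j)` tuples stand in for Ray actor handles):

```python
# Two hypothetical auxiliary models with 2 and 3 engine replicas each.
engine_nums = [2, 3]

# Before: one flat list, replicas of different models interleaved.
flat = [(i, j) for i, n in enumerate(engine_nums) for j in range(n)]
# -> [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]

# After: one sub-list per model, indexable by model id.
nested = [[(i, j) for j in range(n)] for i, n in enumerate(engine_nums)]
# -> [[(0, 0), (0, 1)], [(1, 0), (1, 1), (1, 2)]]

assert [pair for group in nested for pair in group] == flat
```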

trinity/common/models/model.py

Lines changed: 74 additions & 10 deletions
```diff
@@ -71,10 +71,18 @@ def get_api_server_url(self) -> Optional[str]:
         """Get the API server URL if available."""
         return None
 
+    def get_api_key(self) -> str:
+        """Get the API key."""
+        return "EMPTY"
+
     def get_model_config(self) -> InferenceModelConfig:
         """Get the model configuration."""
         return self.config
 
+    def get_model_path(self) -> Optional[str]:
+        """Get the model path."""
+        return self.config.model_path
+
 
 def _history_recorder(func):
     """Decorator to record history of the model calls."""
@@ -118,10 +126,11 @@ def __init__(
             engine_type.startswith("vllm") or engine_type == "tinker"
         ), "Only vLLM and tinker model is supported for now."
         self.model = model
+        self.engine_type = engine_type
         self.config: InferenceModelConfig = None  # init during prepare
         self._model_name: str = None
-        self._model_path: str = None
         self.api_address: str = None
+        self._api_key: str = None
         self.openai_client: openai.OpenAI = None
         self.openai_async_client: openai.AsyncOpenAI = None
         self.logger = get_logger(__name__)
@@ -138,7 +147,7 @@ async def prepare(self) -> None:
         """Prepare the model wrapper."""
         self.config = await self.model.get_model_config.remote()
         self._model_name = self.config.name
-        self._model_path = self.config.model_path
+        self._api_key = await self.model.get_api_key.remote()
         self._generate_kwargs = {
             "temperature": self.config.temperature,
             "top_p": self.config.top_p,
@@ -152,6 +161,8 @@
         if self.api_address is None:
             self.logger.info("API server is not enabled for inference model.")
             return
+        if self.engine_type == "tinker":
+            return
         max_retries = 30
         interval = 2  # seconds
         for i in range(max_retries):
@@ -285,6 +296,11 @@ async def convert_messages_to_experience_async(
             messages, tools=tools, temperature=temperature
         )
 
+    @property
+    def api_key(self) -> str:
+        """Get the API key."""
+        return self._api_key
+
     @property
     def model_version(self) -> int:
         """Get the version of the model."""
@@ -298,7 +314,12 @@ async def model_version_async(self) -> int:
     @property
     def model_path(self) -> str:
         """Get the model path."""
-        return self._model_path
+        return ray.get(self.model.get_model_path.remote())
+
+    @property
+    async def model_path_async(self) -> str:
+        """Get the model path."""
+        return await self.model.get_model_path.remote()
 
     @property
     def model_name(self) -> Optional[str]:
@@ -332,16 +353,38 @@ def get_openai_client(self) -> openai.OpenAI:
             openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
         """
         if self.openai_client is not None:
+            setattr(self.openai_client, "model_path", self.model_path)
             return self.openai_client
         if not self.api_address:
             raise ValueError(
                 "API server is not enabled for this model. OpenAI client is unavailable."
             )
         self.openai_client = openai.OpenAI(
             base_url=f"{self.api_address}/v1",
-            api_key="EMPTY",
+            api_key=self._api_key,
         )
-        if self.enable_history:
+        if self.engine_type == "tinker":
+            # ! TODO: because tinker's OpenAI API interface is in beta,
+            # we need to use the original tinker API instead.
+            ori_create = self.openai_client.chat.completions.create
+
+            def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = ray.get(
+                    self.model.chat.remote(
+                        messages=messages,
+                        with_chat_completion=True,
+                        return_token_ids=self.enable_history,
+                        **kwargs,
+                    )
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_client.chat.completions.create = chat_completions
+        elif self.enable_history:
             # add a decorator to the openai client to record history
 
             ori_create = self.openai_client.chat.completions.create
@@ -359,7 +402,7 @@ def record_chat_completions(*args, **kwargs):
                 return response
 
             self.openai_client.chat.completions.create = record_chat_completions
-        setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id)
+        setattr(self.openai_client, "model_path", self.model_path)
         return self.openai_client
 
     def get_openai_async_client(self) -> openai.AsyncOpenAI:
@@ -369,6 +412,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
             openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
         """
         if self.openai_async_client is not None:
+            setattr(self.openai_async_client, "model_path", self.model_path)
             return self.openai_async_client
         if not self.api_address:
             raise ValueError(
@@ -377,9 +421,29 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
         # first make sure that we have the sync openai client
         self.openai_async_client = openai.AsyncOpenAI(
             base_url=f"{self.api_address}/v1",
-            api_key="EMPTY",
+            api_key=self._api_key,
         )
-        if self.enable_history:
+
+        if self.engine_type == "tinker":
+            # ! TODO: because tinker's OpenAI API interface is in beta,
+            # we need to use the original tinker API instead.
+            ori_create = self.openai_async_client.chat.completions.create
+
+            async def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = await self.model.chat.remote(
+                    messages=messages,
+                    with_chat_completion=True,
+                    return_token_ids=self.enable_history,
+                    **kwargs,
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_async_client.chat.completions.create = chat_completions
+        elif self.enable_history:
             # add a decorator to the openai client to record history
 
             ori_create = self.openai_async_client.chat.completions.create
@@ -400,8 +464,8 @@ async def record_chat_completions(*args, **kwargs):
 
             self.openai_async_client.chat.completions.create = record_chat_completions
         # get model_path from the sync openai client to avoid async call here
-        openai_client = self.get_openai_client()
-        setattr(self.openai_async_client, "model_path", openai_client.models.list().data[0].id)
+        # openai_client = self.get_openai_client()
+        setattr(self.openai_async_client, "model_path", self.model_path)
         return self.openai_async_client
 
     async def get_current_load(self) -> int:
```
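
Because tinker's OpenAI-compatible endpoint is still in beta, the wrapper replaces the client's `chat.completions.create` with a shim that routes the call to the engine's native `chat` method: the last element of the engine's reply becomes the returned completion, and any earlier elements are appended to the wrapper's history. A distilled, self-contained sketch of that shim (`FakeTinkerEngine` and `make_chat_completions` are stand-ins, not repo code):

```python
import asyncio


class FakeTinkerEngine:
    """Stand-in for the remote engine: chat() returns [*extras, completion]."""

    async def chat(self, messages, with_chat_completion=True, return_token_ids=False, **kwargs):
        extras = [{"token_ids": [1, 2, 3]}] if return_token_ids else []
        return extras + [{"role": "assistant", "content": "hello"}]


def make_chat_completions(engine, history, enable_history):
    """Build a drop-in replacement for client.chat.completions.create."""

    async def chat_completions(*args, **kwargs):
        messages = kwargs.pop("messages")
        # Route the OpenAI-style call to the engine's native chat API.
        chat_response = await engine.chat(
            messages=messages,
            with_chat_completion=True,
            return_token_ids=enable_history,
            **kwargs,
        )
        response = chat_response.pop()  # last element: the ChatCompletion
        if enable_history:
            history.extend(chat_response)  # earlier elements: token records
        return response

    return chat_completions


history: list = []
create = make_chat_completions(FakeTinkerEngine(), history, enable_history=True)
reply = asyncio.run(create(messages=[{"role": "user", "content": "hi"}], model="x"))
print(reply)    # {'role': 'assistant', 'content': 'hello'}
print(history)  # [{'token_ids': [1, 2, 3]}]
```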
