Skip to content

Commit e76403a

Browse files
authored
Add model_path to auxiliary_models. (#67)
1 parent e721eab commit e76403a

File tree

5 files changed

+27
-13
lines changed

5 files changed

+27
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
2323
dependencies = [
2424
"verl==0.3.0.post1",
2525
"ray[default]>=2.45.0",
26-
"vllm>=0.8.5",
26+
"vllm==0.8.5.post1",
2727
"tensordict==0.6.2",
2828
"wandb",
2929
"omegaconf",

tests/explorer/runner_pool_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import time
44
import unittest
5-
from typing import List
5+
from typing import List, Tuple
66

77
import ray
88
import torch
@@ -87,8 +87,8 @@ def init_process_group(
8787
def has_api_server(self) -> bool:
8888
return True
8989

90-
def api_server_ready(self) -> str:
91-
return "http://localhosts:12345"
90+
def api_server_ready(self) -> Tuple[str, str]:
91+
return "http://localhosts:12345", "placeholder"
9292

9393

9494
class RunnerPoolTest(unittest.TestCase):

trinity/common/models/model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,21 @@ def get_ckp_version(self) -> int:
103103
return ray.get(self.model.get_ckp_version.remote())
104104

105105
def get_openai_client(self) -> openai.OpenAI:
106+
"""Get the openai client.
107+
108+
Returns:
109+
openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
110+
"""
106111
if self.openai_client is not None:
107112
return self.openai_client
108113
if not ray.get(self.model.has_api_server.remote()):
109114
raise ValueError(
110115
"OpenAI API server is not running on current model."
111116
"Please set `enable_openai_api` to `True`."
112117
)
113-
api_address = None
118+
api_address, model_path = None, None
114119
while True:
115-
api_address = ray.get(self.model.api_server_ready.remote())
120+
api_address, model_path = ray.get(self.model.api_server_ready.remote())
116121
if api_address is not None:
117122
break
118123
else:
@@ -127,4 +132,5 @@ def get_openai_client(self) -> openai.OpenAI:
127132
base_url=api_address,
128133
api_key="EMPTY",
129134
)
135+
setattr(self.openai_client, "model_path", model_path) # TODO: may be removed
130136
return self.openai_client

trinity/common/models/vllm_async_model.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import os
77
import re
8-
from typing import Any, Dict, List, Optional
8+
from typing import Any, Dict, List, Optional, Tuple, Union
99

1010
import aiohttp
1111
import torch
@@ -319,26 +319,30 @@ async def run_api_server(self):
319319
async def has_api_server(self) -> bool:
320320
return self.config.enable_openai_api
321321

322-
async def api_server_ready(self) -> Optional[str]:
322+
async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
323323
"""Check if the OpenAI API server is ready.
324324
325325
Returns:
326-
str: The URL of the OpenAI API server.
326+
api_url (str): The URL of the OpenAI API server.
327+
model_path (str): The path of the model.
327328
"""
328329
if not await self.has_api_server():
329-
return None
330+
return None, None
330331
try:
331332
async with aiohttp.ClientSession() as session:
332333
async with session.get(
333334
f"http://{self.api_server_host}:{self.api_server_port}/health"
334335
) as response:
335336
if response.status == 200:
336-
return f"http://{self.api_server_host}:{self.api_server_port}/v1"
337+
return (
338+
f"http://{self.api_server_host}:{self.api_server_port}/v1",
339+
self.config.model_path,
340+
)
337341
else:
338-
return None
342+
return None, None
339343
except Exception as e:
340344
self.logger.error(e)
341-
return None
345+
return None, None
342346

343347
async def reset_prefix_cache(self) -> None:
344348
await self.async_llm.reset_prefix_cache()

trinity/common/workflows/workflow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def to_workflow(
4141
4242
Args:
4343
model (ModelWrapper): The rollout model for the workflow.
44+
auxiliary_models (List[openai.OpenAI]): The auxiliary models for the workflow.
45+
46+
Note:
47+
`model_path` attribute is added to the `auxiliary_models` for use within the workflow.
4448
4549
Returns:
4650
Workflow: The generated workflow object.

0 commit comments

Comments (0)