2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
     "ray[default]>=2.45.0",
-    "vllm>=0.8.5",
+    "vllm==0.8.5.post1",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",
6 changes: 3 additions & 3 deletions tests/explorer/runner_pool_test.py
@@ -2,7 +2,7 @@
 import os
 import time
 import unittest
-from typing import List
+from typing import List, Tuple

 import ray
 import torch
@@ -87,8 +87,8 @@ def init_process_group(
     def has_api_server(self) -> bool:
         return True

-    def api_server_ready(self) -> str:
-        return "http://localhosts:12345"
+    def api_server_ready(self) -> Tuple[str, str]:
+        return "http://localhosts:12345", "placeholder"


 class RunnerPoolTest(unittest.TestCase):
10 changes: 8 additions & 2 deletions trinity/common/models/model.py
@@ -103,16 +103,21 @@ def get_ckp_version(self) -> int:
         return ray.get(self.model.get_ckp_version.remote())

     def get_openai_client(self) -> openai.OpenAI:
+        """Get the OpenAI client.
+
+        Returns:
+            openai.OpenAI: The OpenAI client, with a `model_path` attribute attached that refers to the path of the served model.
+        """
         if self.openai_client is not None:
             return self.openai_client
         if not ray.get(self.model.has_api_server.remote()):
             raise ValueError(
                 "OpenAI API server is not running on current model. "
                 "Please set `enable_openai_api` to `True`."
             )
-        api_address = None
+        api_address, model_path = None, None
         while True:
-            api_address = ray.get(self.model.api_server_ready.remote())
+            api_address, model_path = ray.get(self.model.api_server_ready.remote())
             if api_address is not None:
                 break
             else:
@@ -127,4 +132,5 @@ def get_openai_client(self) -> openai.OpenAI:
             base_url=api_address,
             api_key="EMPTY",
         )
+        setattr(self.openai_client, "model_path", model_path)
         return self.openai_client
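Usage sketch (not part of this diff): how a caller might consume the new attribute. The `model_wrapper` name below is an assumed `ModelWrapper` instance; only the `model_path` attribute itself comes from `get_openai_client()` above.

```python
# Minimal sketch, assuming `model_wrapper` is an existing ModelWrapper instance.
client = model_wrapper.get_openai_client()

# The client now carries the served model's path, so callers no longer need
# to hard-code a model name when issuing OpenAI-compatible requests.
response = client.chat.completions.create(
    model=client.model_path,  # attribute attached by get_openai_client()
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```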
18 changes: 11 additions & 7 deletions trinity/common/models/vllm_async_model.py
@@ -5,7 +5,7 @@

 import os
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

 import aiohttp
 import torch
@@ -319,26 +319,30 @@ async def run_api_server(self):
     async def has_api_server(self) -> bool:
         return self.config.enable_openai_api

-    async def api_server_ready(self) -> Optional[str]:
+    async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
         """Check if the OpenAI API server is ready.

         Returns:
-            str: The URL of the OpenAI API server.
+            api_url (str): The URL of the OpenAI API server.
+            model_path (str): The path of the model.
         """
         if not await self.has_api_server():
-            return None
+            return None, None
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.get(
                     f"http://{self.api_server_host}:{self.api_server_port}/health"
                 ) as response:
                     if response.status == 200:
-                        return f"http://{self.api_server_host}:{self.api_server_port}/v1"
+                        return (
+                            f"http://{self.api_server_host}:{self.api_server_port}/v1",
+                            self.config.model_path,
+                        )
                     else:
-                        return None
+                        return None, None
         except Exception as e:
             self.logger.error(e)
-            return None
+            return None, None

     async def reset_prefix_cache(self) -> None:
         await self.async_llm.reset_prefix_cache()
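For context, the check above follows vLLM's convention of exposing a `/health` route alongside the OpenAI-compatible `/v1` routes. A self-contained sketch of the same polling pattern, where the helper name, retry count, and delay are illustrative assumptions rather than code from this repository:

```python
import asyncio

import aiohttp


async def wait_for_api_server(base_url: str, retries: int = 30, delay: float = 2.0):
    """Poll `{base_url}/health` until it returns 200, or give up after `retries` attempts."""
    async with aiohttp.ClientSession() as session:
        for _ in range(retries):
            try:
                async with session.get(f"{base_url}/health") as response:
                    if response.status == 200:
                        # OpenAI-compatible endpoints are served under /v1.
                        return f"{base_url}/v1"
            except aiohttp.ClientError:
                pass  # server not accepting connections yet
            await asyncio.sleep(delay)
    return None
```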
4 changes: 4 additions & 0 deletions trinity/common/workflows/workflow.py
@@ -41,6 +41,10 @@ def to_workflow(

         Args:
             model (ModelWrapper): The rollout model for the workflow.
             auxiliary_models (List[openai.OpenAI]): The auxiliary models for the workflow.
+
+        Note:
+            A `model_path` attribute is added to each client in `auxiliary_models` for use within the workflow.
+
         Returns:
             Workflow: The generated workflow object.
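Illustrative only: inside a workflow, the attached attribute lets an auxiliary client be queried without knowing in advance which model it serves. The `JudgeWorkflow` class, its `run` body, and the `self.auxiliary_models` attribute name are assumptions made for this example; only the `model_path` attribute comes from this PR.

```python
from trinity.common.workflows.workflow import Workflow


class JudgeWorkflow(Workflow):  # hypothetical subclass, for illustration
    def run(self):
        judge = self.auxiliary_models[0]  # an openai.OpenAI client; attribute name assumed
        result = judge.chat.completions.create(
            model=judge.model_path,  # attribute attached in get_openai_client()
            messages=[{"role": "user", "content": "Score this answer from 1 to 10: ..."}],
        )
        return result.choices[0].message.content
```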