2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
     "ray[default]>=2.45.0",
-    "vllm>=0.8.5",
+    "vllm==0.8.5.post1",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",
6 changes: 3 additions & 3 deletions tests/explorer/runner_pool_test.py
@@ -2,7 +2,7 @@
 import os
 import time
 import unittest
-from typing import List
+from typing import List, Tuple

 import ray
 import torch
@@ -87,8 +87,8 @@ def init_process_group(
     def has_api_server(self) -> bool:
         return True

-    def api_server_ready(self) -> str:
-        return "http://localhosts:12345"
+    def api_server_ready(self) -> Tuple[str, str]:
+        return "http://localhosts:12345", "placeholder"


 class RunnerPoolTest(unittest.TestCase):
10 changes: 8 additions & 2 deletions trinity/common/models/model.py
@@ -103,16 +103,21 @@ def get_ckp_version(self) -> int:
         return ray.get(self.model.get_ckp_version.remote())

     def get_openai_client(self) -> openai.OpenAI:
+        """Get the openai client.
+
+        Returns:
+            openai.OpenAI: The openai client. A `model_path` attribute is added to the client, referring to the model path.
+        """
         if self.openai_client is not None:
             return self.openai_client
         if not ray.get(self.model.has_api_server.remote()):
             raise ValueError(
                 "OpenAI API server is not running on current model."
                 "Please set `enable_openai_api` to `True`."
             )
-        api_address = None
+        api_address, model_path = None, None
         while True:
-            api_address = ray.get(self.model.api_server_ready.remote())
+            api_address, model_path = ray.get(self.model.api_server_ready.remote())
             if api_address is not None:
                 break
             else:
@@ -127,4 +132,5 @@ def get_openai_client(self) -> openai.OpenAI:
             base_url=api_address,
             api_key="EMPTY",
         )
+        setattr(self.openai_client, "model_path", model_path)  # TODO: may be removed
         return self.openai_client
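As a usage sketch of the change above (the `wrapper` object is a hypothetical `ModelWrapper` instance; only `get_openai_client()` and the attached `model_path` attribute come from this diff), a caller can now pass the served model's path without knowing it out of band:

```python
# Hypothetical caller code; `wrapper` is assumed to be a ModelWrapper whose
# underlying model was started with `enable_openai_api` set to True.
client = wrapper.get_openai_client()

# `model_path` is the attribute attached via setattr() in get_openai_client(),
# so it can be used directly as the `model` argument of the OpenAI API.
response = client.chat.completions.create(
    model=client.model_path,
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```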
18 changes: 11 additions & 7 deletions trinity/common/models/vllm_async_model.py
@@ -5,7 +5,7 @@

 import os
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

 import aiohttp
 import torch
@@ -319,26 +319,30 @@ async def run_api_server(self):
     async def has_api_server(self) -> bool:
         return self.config.enable_openai_api

-    async def api_server_ready(self) -> Optional[str]:
+    async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
         """Check if the OpenAI API server is ready.

         Returns:
-            str: The URL of the OpenAI API server.
+            api_url (str): The URL of the OpenAI API server.
+            model_path (str): The path of the model.
         """
         if not await self.has_api_server():
-            return None
+            return None, None
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.get(
                     f"http://{self.api_server_host}:{self.api_server_port}/health"
                 ) as response:
                     if response.status == 200:
-                        return f"http://{self.api_server_host}:{self.api_server_port}/v1"
+                        return (
+                            f"http://{self.api_server_host}:{self.api_server_port}/v1",
+                            self.config.model_path,
+                        )
                     else:
-                        return None
+                        return None, None
         except Exception as e:
             self.logger.error(e)
-            return None
+            return None, None

     async def reset_prefix_cache(self) -> None:
         await self.async_llm.reset_prefix_cache()
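For illustration, a sketch of how the new two-element return value can be consumed (the actor handle name and the retry interval are assumptions; the tuple contract itself is what the diff introduces). It mirrors the polling loop in `ModelWrapper.get_openai_client` above:

```python
import time

import ray

# `model_actor` is assumed to be a Ray actor handle for the async vLLM model.
api_address, model_path = None, None
while api_address is None:
    # Returns (None, None) until the /health endpoint reports 200.
    api_address, model_path = ray.get(model_actor.api_server_ready.remote())
    if api_address is None:
        time.sleep(1)  # not ready yet; back off and retry
print(f"API at {api_address}, serving model {model_path}")
```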
4 changes: 4 additions & 0 deletions trinity/common/workflows/workflow.py
@@ -41,6 +41,10 @@ def to_workflow(

         Args:
             model (ModelWrapper): The rollout model for the workflow.
+            auxiliary_models (List[openai.OpenAI]): The auxiliary models for the workflow.
+
+        Note:
+            A `model_path` attribute is added to each client in `auxiliary_models` for use within the workflow.

         Returns:
             Workflow: The generated workflow object.
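To make the note concrete, a hedged sketch of how a workflow body might use the attribute (`self.auxiliary_models` as the stored attribute name and the prompt are assumptions; the `model_path` attribute on each client is what this PR guarantees):

```python
# Inside a hypothetical Workflow.run() implementation.
for client in self.auxiliary_models:
    # Each auxiliary client is an openai.OpenAI instance that carries the
    # `model_path` attribute attached in ModelWrapper.get_openai_client().
    judgement = client.chat.completions.create(
        model=client.model_path,
        messages=[{"role": "user", "content": "Rate this response from 1 to 10."}],
    )
```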