diff --git a/pyproject.toml b/pyproject.toml
index dcf86f8349..022c9a8ffe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
     "ray[default]>=2.45.0",
-    "vllm>=0.8.5",
+    "vllm==0.8.5.post1",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",
diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 036339e747..285961dc37 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -2,7 +2,7 @@
 import os
 import time
 import unittest
-from typing import List
+from typing import List, Tuple
 
 import ray
 import torch
@@ -87,8 +87,8 @@ def init_process_group(
     def has_api_server(self) -> bool:
         return True
 
-    def api_server_ready(self) -> str:
-        return "http://localhosts:12345"
+    def api_server_ready(self) -> Tuple[str, str]:
+        return "http://localhosts:12345", "placeholder"
 
 
 class RunnerPoolTest(unittest.TestCase):
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index b5104f2cc7..cb15b1ae3d 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -103,6 +103,11 @@ def get_ckp_version(self) -> int:
         return ray.get(self.model.get_ckp_version.remote())
 
     def get_openai_client(self) -> openai.OpenAI:
+        """Get the OpenAI client.
+
+        Returns:
+            openai.OpenAI: The OpenAI client, with a `model_path` attribute attached that refers to the path of the model.
+        """
         if self.openai_client is not None:
             return self.openai_client
         if not ray.get(self.model.has_api_server.remote()):
@@ -110,9 +115,9 @@ def get_openai_client(self) -> openai.OpenAI:
                 "OpenAI API server is not running on current model."
                 "Please set `enable_openai_api` to `True`."
             )
-        api_address = None
+        api_address, model_path = None, None
         while True:
-            api_address = ray.get(self.model.api_server_ready.remote())
+            api_address, model_path = ray.get(self.model.api_server_ready.remote())
             if api_address is not None:
                 break
             else:
@@ -127,4 +132,5 @@ def get_openai_client(self) -> openai.OpenAI:
             base_url=api_address,
             api_key="EMPTY",
         )
+        setattr(self.openai_client, "model_path", model_path)  # TODO: may be removed
         return self.openai_client
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 02ea52ec58..27faa4c44a 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -5,7 +5,7 @@
 
 import os
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import aiohttp
 import torch
@@ -319,26 +319,30 @@ async def run_api_server(self):
     async def has_api_server(self) -> bool:
         return self.config.enable_openai_api
 
-    async def api_server_ready(self) -> Optional[str]:
+    async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
         """Check if the OpenAI API server is ready.
 
         Returns:
-            str: The URL of the OpenAI API server.
+            api_url (str): The URL of the OpenAI API server.
+            model_path (str): The path of the model.
         """
         if not await self.has_api_server():
-            return None
+            return None, None
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.get(
                     f"http://{self.api_server_host}:{self.api_server_port}/health"
                 ) as response:
                     if response.status == 200:
-                        return f"http://{self.api_server_host}:{self.api_server_port}/v1"
+                        return (
+                            f"http://{self.api_server_host}:{self.api_server_port}/v1",
+                            self.config.model_path,
+                        )
                     else:
-                        return None
+                        return None, None
         except Exception as e:
             self.logger.error(e)
-            return None
+            return None, None
 
     async def reset_prefix_cache(self) -> None:
         await self.async_llm.reset_prefix_cache()
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 9786bd6b77..169ad63279 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -41,6 +41,10 @@ def to_workflow(
 
         Args:
             model (ModelWrapper): The rollout model for the workflow.
+            auxiliary_models (List[openai.OpenAI]): The auxiliary models for the workflow.
+
+        Note:
+            A `model_path` attribute is added to each client in `auxiliary_models` for use within the workflow.
 
         Returns:
             Workflow: The generated workflow object.