Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions trinity/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def create_inference_models(
config=config.explorer.rollout_model,
)
)
if config.explorer.rollout_model.enable_openai_api:
for engine in rollout_engines:
engine.run_api_server.remote()

if config.explorer.rollout_model.enable_history:
logger.info(
"Model History recording is enabled. Please periodically extract "
Expand Down Expand Up @@ -138,10 +136,6 @@ def create_inference_models(
.remote(config=model_config)
)
auxiliary_engines.append(engines)
# all auxiliary engines run api server
for engines in auxiliary_engines:
for engine in engines:
engine.run_api_server.remote()

return rollout_engines, auxiliary_engines

Expand All @@ -159,10 +153,10 @@ def create_debug_inference_model(config: Config) -> None:
rollout_models, auxiliary_models = create_inference_models(config)
# make sure models are started
for m in rollout_models:
ray.get(m.get_model_path.remote())
ray.get(m.run_api_server.remote())
for models in auxiliary_models:
for m in models:
ray.get(m.get_model_path.remote())
ray.get(m.run_api_server.remote())
logger.info(
"----------------------------------------------------\n"
"Inference models started successfully for debugging.\n"
Expand Down
7 changes: 5 additions & 2 deletions trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,11 @@ async def init_process_group(

async def run_api_server(self):
"""Run the OpenAI API server in a Ray actor."""
if not (self.api_server_host is None or self.api_server_port is None):
raise RuntimeError("API server is already running.")
if not self.config.enable_openai_api or not (
self.api_server_host is None or self.api_server_port is None
):
return # no need to run or already running

from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor

self.api_server_host, self.api_server_port = self.get_available_address()
Expand Down
12 changes: 9 additions & 3 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,18 @@ async def _nccl_weights_update(self):
async def prepare(self) -> None:
"""Preparation before running."""
try:
# prepare experience pipeline
await self.experience_pipeline.prepare.remote()
self.logger.info("Experience pipeline is ready.")
# make sure all rollout models are ready
model_ready_ref = [model.__ray_ready__.remote() for model in self.models]
await asyncio.gather(*model_ready_ref)
self.logger.info("All rollout models are ready.")
run_api_ref = []
for model in self.models:
run_api_ref.append(model.run_api_server.remote())
for models in self.auxiliary_models:
for model in models:
run_api_ref.append(model.run_api_server.remote())
await asyncio.gather(*run_api_ref)
self.logger.info("All models are ready.")

if not self.use_nccl_sync:
master_address, master_port = await self.models[0].get_available_address.remote()
Expand Down