Commit 9efaec8

Fix openai API server setup (#348)
1 parent 42d10b4 commit 9efaec8

File tree

8 files changed: +98 -72 lines


tests/explorer/explorer_test.py

Lines changed: 21 additions & 4 deletions
@@ -15,14 +15,15 @@
     RayUnittestBase,
     RayUnittestBaseAysnc,
     TensorBoardParser,
+    get_api_model_path,
     get_checkpoint_path,
     get_model_path,
     get_template_config,
     get_unittest_dataset_config,
 )
 from trinity.buffer import get_buffer_reader
 from trinity.cli.launcher import explore, run_stage
-from trinity.common.config import ExperienceBufferConfig
+from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
 from trinity.common.constants import StorageType
 from trinity.explorer.explorer import Explorer
 from trinity.manager.state_manager import StateManager
@@ -31,6 +32,7 @@
 class BaseExplorerCase(RayUnittestBase):
     def setUp(self):
         self.config = get_template_config()
+        self.config.mode = "explore"
         self.config.buffer.total_epochs = 2
         self.config.buffer.batch_size = 4
         self.config.model.model_path = get_model_path()
@@ -67,18 +69,33 @@ def test_explorer(self):
         self.assertTrue("eval/eval_long/accuracy/max" in eval_metrics)


-class TestExplorerCountdownNoEval(BaseExplorerCase):
+class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
     def test_explorer(self):
-        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        self.config.explorer.rollout_model.engine_num = 2
+        self.config.explorer.auxiliary_models = [
+            InferenceModelConfig(
+                model_path=get_api_model_path(),
+                tensor_parallel_size=1,
+                engine_num=2,
+            )
+        ]
+        self.config.algorithm.repeat_times = 2
+        self.config.buffer.total_steps = 2
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k_ruler")
         self.config.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.algorithm.algorithm_type = "grpo"
+        self.config.algorithm.advantage_fn = "grpo"
+        self.config.algorithm.advantage_fn_args = {
+            "std_threshold": 0.0001,
+        }
         self.config.check_and_update()
         explore(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
         self.assertTrue(len(rollout_metrics) > 0)
         eval_metrics = parser.metric_list("eval")
         self.assertTrue(len(eval_metrics) == 0)
-        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)


 class TestExplorerGSM8k(BaseExplorerCase):

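The rewritten test drives the explorer against auxiliary models reachable through an OpenAI-compatible endpoint. For context, a minimal sketch of how client code might call such an endpoint, assuming the standard openai Python client and a URL obtained from get_api_server_url (the helper below is illustrative, not code from this repository):

from openai import OpenAI

def query_auxiliary_model(api_url: str, model_name: str, prompt: str) -> str:
    # vLLM's OpenAI-compatible server is served under /v1; the api_key
    # value is unused by vLLM but required by the client.
    client = OpenAI(base_url=f"{api_url}/v1", api_key="EMPTY")
    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content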
tests/explorer/scheduler_test.py

Lines changed: 2 additions & 8 deletions
@@ -134,10 +134,7 @@ def init_process_group(
     ) -> None:
         pass

-    def has_api_server(self) -> bool:
-        return False
-
-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         return None


@@ -161,10 +158,7 @@ def init_process_group(
     ) -> None:
         pass

-    def has_api_server(self) -> bool:
-        return True
-
-    def get_api_server_url(self) -> str:
+    async def get_api_server_url(self) -> str:
         return "http://localhost:12345"

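These test stubs track the new contract: the separate has_api_server flag is gone, and availability is signaled by get_api_server_url returning either a URL or None. A minimal sketch of a caller under that contract (class and function names are illustrative):

import asyncio
from typing import Optional

class DummyModelWithoutAPI:
    async def get_api_server_url(self) -> Optional[str]:
        return None  # API server disabled for this model

async def connect(model) -> Optional[str]:
    # A single awaited call both reports availability and yields the URL.
    url = await model.get_api_server_url()
    if url is None:
        print("API server not enabled; skipping HTTP setup.")
    return url

asyncio.run(connect(DummyModelWithoutAPI()))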
tests/tools.py

Lines changed: 11 additions & 0 deletions
@@ -116,6 +116,17 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             default_workflow_type="math_workflow",
             default_reward_fn_type="math_reward",
         )
+    elif dataset_name == "gsm8k_ruler":
+        return TasksetConfig(
+            name=dataset_name,
+            path=os.path.join(os.path.dirname(__file__), "template", "data", "gsm8k"),
+            split="train",
+            format=FormatConfig(
+                prompt_key="question",
+                response_key="answer",
+            ),
+            default_workflow_type="math_ruler_workflow",
+        )
     elif dataset_name == "sft_for_gsm8k":
         # SFT dataset with 8 samples
         return ExperienceBufferConfig(

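A hypothetical usage sketch for the new dataset entry, assuming the helper is imported from tests.tools as in the explorer test above:

from tests.tools import get_unittest_dataset_config

# "gsm8k_ruler" reuses the bundled gsm8k data but routes it through the
# RULER-style math workflow declared in the config above.
taskset = get_unittest_dataset_config("gsm8k_ruler")
assert taskset.default_workflow_type == "math_ruler_workflow"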
trinity/common/config.py

Lines changed: 1 addition & 1 deletion
@@ -436,7 +436,7 @@ class InferenceModelConfig:
     # ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
     model_path: str = ""

-    engine_type: str = "vllm_async"
+    engine_type: str = "vllm"
     engine_num: int = 1
     tensor_parallel_size: int = 1
     use_v1: bool = True

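With the default engine_type now "vllm", an auxiliary model can be configured without naming the engine explicitly. A sketch using only fields visible in this commit (the model path is a placeholder):

from trinity.common.config import InferenceModelConfig

aux_model = InferenceModelConfig(
    model_path="/path/to/auxiliary/model",  # placeholder path
    engine_num=2,
    tensor_parallel_size=1,
    # engine_type defaults to "vllm" after this change; override only if needed.
)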
trinity/common/models/__init__.py

Lines changed: 3 additions & 9 deletions
@@ -106,9 +106,7 @@ def create_inference_models(
             config=config.explorer.rollout_model,
         )
     )
-    if config.explorer.rollout_model.enable_openai_api:
-        for engine in rollout_engines:
-            engine.run_api_server.remote()
+
     if config.explorer.rollout_model.enable_history:
         logger.info(
             "Model History recording is enabled. Please periodically extract "
@@ -138,10 +136,6 @@ def create_inference_models(
             .remote(config=model_config)
         )
         auxiliary_engines.append(engines)
-    # all auxiliary engines run api server
-    for engines in auxiliary_engines:
-        for engine in engines:
-            engine.run_api_server.remote()

     return rollout_engines, auxiliary_engines

@@ -159,10 +153,10 @@ def create_debug_inference_model(config: Config) -> None:
     rollout_models, auxiliary_models = create_inference_models(config)
     # make sure models are started
     for m in rollout_models:
-        ray.get(m.get_model_path.remote())
+        ray.get(m.run_api_server.remote())
     for models in auxiliary_models:
         for m in models:
-            ray.get(m.get_model_path.remote())
+            ray.get(m.run_api_server.remote())
     logger.info(
         "----------------------------------------------------\n"
         "Inference models started successfully for debugging.\n"

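The net effect: create_inference_models no longer launches API servers eagerly; callers opt in per engine. Since run_api_server is now idempotent (see vllm_model.py below), the opt-in can be issued unconditionally. A minimal sketch, assuming ray is initialized and engines holds actor handles:

import ray

# Safe to call on every engine: already-running servers return True
# immediately, and engines with the API disabled return False.
ray.get([engine.run_api_server.remote() for engine in engines])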
trinity/common/models/model.py

Lines changed: 19 additions & 25 deletions
@@ -51,11 +51,7 @@ def get_available_address(self) -> Tuple[str, int]:
         port = s.getsockname()[1]
         return address, port

-    def has_api_server(self) -> bool:
-        """Check if the model has an API server."""
-        return False
-
-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         """Get the API server URL if available."""
         return None

@@ -106,26 +102,24 @@ def __init__(

     async def prepare(self) -> None:
         """Prepare the model wrapper."""
-        if await self.model.has_api_server.remote():
-            self.api_address = await self.model.get_api_server_url.remote()
-            if self.api_address is None:
-                raise RuntimeError(
-                    "Failed to connect to the API server. Please set `enable_openai_api` to `True`."
-                )
-            max_retries = 30
-            interval = 2  # seconds
-            for i in range(max_retries):
-                try:
-                    async with httpx.AsyncClient() as client:
-                        response = await client.get(self.api_address + "/health", timeout=5)
-                    if response.status_code == 200:
-                        return
-                except Exception as e:
-                    self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
-                await asyncio.sleep(interval)
-            raise RuntimeError(
-                f"API server at {self.api_address} not ready after {max_retries} attempts."
-            )
+        self.api_address = await self.model.get_api_server_url.remote()
+        if self.api_address is None:
+            self.logger.info("API server is not enabled for inference model.")
+            return
+        max_retries = 30
+        interval = 2  # seconds
+        for i in range(max_retries):
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.get(self.api_address + "/health", timeout=5)
+                if response.status_code == 200:
+                    return
+            except Exception as e:
+                self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
+            await asyncio.sleep(interval)
+        raise RuntimeError(
+            f"API server at {self.api_address} not ready after {max_retries} attempts."
+        )

     def _record_history(self, exps: Union[Experience, List[Experience]]) -> None:
         """Record experiences to history."""

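The retry loop above is a bounded health poll. The same pattern as a standalone, runnable sketch against any HTTP server exposing a /health route (function name and defaults are illustrative):

import asyncio

import httpx

async def wait_for_health(base_url: str, max_retries: int = 30, interval: float = 2.0) -> None:
    for attempt in range(max_retries):
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(base_url + "/health", timeout=5)
            if response.status_code == 200:
                return  # server is ready
        except Exception as exc:
            print(f"not ready (attempt {attempt + 1}/{max_retries}): {exc}")
        await asyncio.sleep(interval)
    raise RuntimeError(f"{base_url} not ready after {max_retries} attempts")

# Example: asyncio.run(wait_for_health("http://localhost:12345"))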
trinity/common/models/vllm_model.py

Lines changed: 32 additions & 22 deletions
@@ -103,6 +103,7 @@ def __init__(
         self.api_server_host = None
         self.api_server_port = None
         self.api_server = None
+        self.async_lock = asyncio.Lock()

     async def _initialize_tokenizer(self):
         if self.tokenizer is None:
@@ -455,35 +456,44 @@ async def init_process_group(
             ),
         )

-    async def run_api_server(self):
-        """Run the OpenAI API server in a Ray actor."""
-        if not (self.api_server_host is None or self.api_server_port is None):
-            raise RuntimeError("API server is already running.")
-        from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
-
-        self.api_server_host, self.api_server_port = self.get_available_address()
-        self.api_server = asyncio.create_task(
-            run_api_server_in_ray_actor(
-                self.async_llm,
-                self.api_server_host,
-                self.api_server_port,
-                self.config.model_path,
-                self.config.enable_auto_tool_choice,
-                self.config.tool_call_parser,
-                self.config.reasoning_parser,
-            )
-        )
+    async def run_api_server(self) -> bool:
+        """Run the OpenAI API server in a Ray actor.

-    def has_api_server(self) -> bool:
-        return self.config.enable_openai_api
+        Returns:
+            success (bool): Whether the API server is started successfully.
+        """
+        async with self.async_lock:
+            if not self.config.enable_openai_api:
+                return False  # Not enabled
+
+            if self.api_server_host is not None and self.api_server_port is not None:
+                return True  # already running
+
+            from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
+
+            api_server_host, api_server_port = self.get_available_address()
+            self.api_server = asyncio.create_task(
+                run_api_server_in_ray_actor(
+                    self.async_llm,
+                    api_server_host,
+                    api_server_port,
+                    self.config.model_path,
+                    self.config.enable_auto_tool_choice,
+                    self.config.tool_call_parser,
+                    self.config.reasoning_parser,
+                )
+            )
+            self.api_server_host = api_server_host
+            self.api_server_port = api_server_port
+            return True

-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         """Get the URL of the OpenAI API server.

         Returns:
             api_url (str): The URL of the OpenAI API server.
         """
-        if not self.has_api_server():
+        if not await self.run_api_server():
             return None
         return f"http://{self.api_server_host}:{self.api_server_port}"


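The rewritten run_api_server is a lock-guarded, idempotent lazy start: concurrent callers serialize on async_lock, the first caller launches the server, and subsequent calls return immediately. Routing get_api_server_url through run_api_server then gives on-demand startup: the first component that asks for the URL brings the server up. A distilled sketch of the pattern, independent of vLLM:

import asyncio
from typing import Optional

class LazyServer:
    def __init__(self, enabled: bool):
        self.enabled = enabled
        self.url: Optional[str] = None
        self._lock = asyncio.Lock()

    async def _start(self) -> str:
        await asyncio.sleep(0.1)  # stand-in for real server startup
        return "http://127.0.0.1:8000"

    async def run(self) -> bool:
        async with self._lock:  # serialize concurrent starters
            if not self.enabled:
                return False  # feature disabled
            if self.url is not None:
                return True  # already running
            self.url = await self._start()
            return True

async def main():
    server = LazyServer(enabled=True)
    # Five concurrent callers; exactly one performs the startup.
    print(await asyncio.gather(*(server.run() for _ in range(5))))

asyncio.run(main())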
trinity/explorer/explorer.py

Lines changed: 9 additions & 3 deletions
@@ -150,12 +150,18 @@ async def _nccl_weights_update(self):
     async def prepare(self) -> None:
         """Preparation before running."""
         try:
+            # prepare experience pipeline
            await self.experience_pipeline.prepare.remote()
            self.logger.info("Experience pipeline is ready.")
             # make sure all rollout models are ready
-            model_ready_ref = [model.__ray_ready__.remote() for model in self.models]
-            await asyncio.gather(*model_ready_ref)
-            self.logger.info("All rollout models are ready.")
+            run_api_ref = [model.run_api_server.remote() for model in self.models]
+            run_api_ref.extend(
+                model.run_api_server.remote()
+                for models in self.auxiliary_models
+                for model in models
+            )
+            await asyncio.gather(*run_api_ref)
+            self.logger.info("All models are ready.")

             if not self.use_nccl_sync:
                 master_address, master_port = await self.models[0].get_available_address.remote()

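Explorer.prepare now fans out run_api_server across rollout and auxiliary engines in a single gather, which also replaces the old __ray_ready__ readiness probe. Ray ObjectRefs are awaitable, so they compose directly with asyncio.gather; an illustrative sketch:

import asyncio

import ray

@ray.remote
class Engine:
    def run_api_server(self) -> bool:
        return True  # stand-in for the real lazy server start

async def prepare(engines) -> None:
    # Ray ObjectRefs are awaitable and compose with asyncio.gather.
    refs = [engine.run_api_server.remote() for engine in engines]
    print("all models ready:", await asyncio.gather(*refs))

ray.init()
asyncio.run(prepare([Engine.remote() for _ in range(2)]))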