From d535a6c79eff9e8c622081332cfaa4c29f5f78e9 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 29 Oct 2025 15:39:39 +0800
Subject: [PATCH 1/4] fix api server startup
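
The OpenAI-compatible API server used to be launched eagerly inside
create_inference_models() with fire-and-forget engine.run_api_server.remote()
calls, and run_api_server() raised a RuntimeError when invoked twice. Start
the servers explicitly instead, once the engines actually exist: the debug
entry point now awaits run_api_server() in place of get_model_path(), and
Explorer.prepare() launches the server on every rollout and auxiliary engine
before use. run_api_server() itself returns early when the API is disabled
or the server is already up.

A minimal sketch of the resulting startup order (using only names that
appear in this patch):

    rollout_engines, auxiliary_engines = create_inference_models(config)
    # API servers are started explicitly afterwards, not at creation time
    ray.get([e.run_api_server.remote() for e in rollout_engines])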
---
 trinity/common/models/__init__.py   | 12 +++---------
 trinity/common/models/vllm_model.py |  7 +++++--
 trinity/explorer/explorer.py        | 12 +++++++++---
 3 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index 9b12b48340..ffdc98c070 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -106,9 +106,7 @@ def create_inference_models(
             config=config.explorer.rollout_model,
         )
     )
-    if config.explorer.rollout_model.enable_openai_api:
-        for engine in rollout_engines:
-            engine.run_api_server.remote()
+
     if config.explorer.rollout_model.enable_history:
         logger.info(
             "Model History recording is enabled. Please periodically extract "
@@ -138,10 +136,6 @@ def create_inference_models(
             .remote(config=model_config)
         )
         auxiliary_engines.append(engines)
-    # all auxiliary engines run api server
-    for engines in auxiliary_engines:
-        for engine in engines:
-            engine.run_api_server.remote()
     return rollout_engines, auxiliary_engines
 
 
@@ -159,10 +153,10 @@ def create_debug_inference_model(config: Config) -> None:
     rollout_models, auxiliary_models = create_inference_models(config)
     # make sure models are started
     for m in rollout_models:
-        ray.get(m.get_model_path.remote())
+        ray.get(m.run_api_server.remote())
     for models in auxiliary_models:
         for m in models:
-            ray.get(m.get_model_path.remote())
+            ray.get(m.run_api_server.remote())
     logger.info(
         "----------------------------------------------------\n"
         "Inference models started successfully for debugging.\n"
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index ee8e30e293..d5174a41cb 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -457,8 +457,11 @@ async def init_process_group(
 
     async def run_api_server(self):
         """Run the OpenAI API server in a Ray actor."""
-        if not (self.api_server_host is None or self.api_server_port is None):
-            raise RuntimeError("API server is already running.")
+        if not self.config.enable_openai_api or not (
+            self.api_server_host is None or self.api_server_port is None
+        ):
+            return  # no need to run or already running
+
         from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
 
         self.api_server_host, self.api_server_port = self.get_available_address()
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index bf748e4877..e7f79999d7 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -150,12 +150,18 @@ async def _nccl_weights_update(self):
 
     async def prepare(self) -> None:
         """Preparation before running."""
         try:
+            # prepare experience pipeline
             await self.experience_pipeline.prepare.remote()
             self.logger.info("Experience pipeline is ready.")
             # make sure all rollout models are ready
-            model_ready_ref = [model.__ray_ready__.remote() for model in self.models]
-            await asyncio.gather(*model_ready_ref)
-            self.logger.info("All rollout models are ready.")
+            run_api_ref = []
+            for model in self.models:
+                run_api_ref.append(model.run_api_server.remote())
+            for models in self.auxiliary_models:
+                for model in models:
+                    run_api_ref.append(model.run_api_server.remote())
+            await asyncio.gather(*run_api_ref)
+            self.logger.info("All models are ready.")
             if not self.use_nccl_sync:
                 master_address, master_port = await self.models[0].get_available_address.remote()

From 150ed00da5ff7a26b0dabad5934aa37a010a4f6a Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 29 Oct 2025 16:23:06 +0800
Subject: [PATCH 2/4] simplify openai server setup
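
Fold the has_api_server()/get_api_server_url() pair into a single async
get_api_server_url() and make run_api_server() an idempotent coroutine that
reports success: it returns False when the OpenAI API is disabled, True when
the server is already up, and otherwise starts the server under an
asyncio.Lock so concurrent callers cannot race and start it twice.
ModelWrapper.prepare() now simply skips the health check when no URL is
returned instead of branching on has_api_server(). The tool-call and
reasoning parser options also get concrete defaults (hermes / deepseek_r1).

Because the call is now safe to repeat, a caller can probe and start the
server in one step. A rough sketch, assuming enable_openai_api=True on a
hypothetical engine handle:

    ok = await engine.run_api_server.remote()       # starts the server, True
    ok = await engine.run_api_server.remote()       # no-op, still True
    url = await engine.get_api_server_url.remote()  # "http://<host>:<port>"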
---
 tests/explorer/scheduler_test.py    | 10 +----
 trinity/common/config.py            |  6 +--
 trinity/common/models/model.py      | 44 ++++++++++------------
 trinity/common/models/vllm_model.py | 57 ++++++++++++++++-------------
 trinity/explorer/explorer.py        | 12 +++---
 5 files changed, 62 insertions(+), 67 deletions(-)

diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index d53d492046..306cc8e447 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -134,10 +134,7 @@ def init_process_group(
     ) -> None:
         pass
 
-    def has_api_server(self) -> bool:
-        return False
-
-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         return None
 
 
@@ -161,10 +158,7 @@ def init_process_group(
     ) -> None:
         pass
 
-    def has_api_server(self) -> bool:
-        return True
-
-    def get_api_server_url(self) -> str:
+    async def get_api_server_url(self) -> str:
         return "http://localhost:12345"
 
 
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 19889477e8..b4a4837b4a 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -471,11 +471,11 @@ class InferenceModelConfig:
     enable_openai_api: bool = False
 
     # For tool calls in OpenAI API
-    enable_auto_tool_choice: bool = False
+    enable_auto_tool_choice: bool = True
 
-    tool_call_parser: Optional[str] = None
+    tool_call_parser: str = "hermes"
 
-    reasoning_parser: Optional[str] = None
+    reasoning_parser: str = "deepseek_r1"
 
     # ! DO NOT SET
     bundle_indices: str = ""
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 2a709f2fa8..c22f6cbc99 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -51,11 +51,7 @@ def get_available_address(self) -> Tuple[str, int]:
         port = s.getsockname()[1]
         return address, port
 
-    def has_api_server(self) -> bool:
-        """Check if the model has an API server."""
-        return False
-
-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         """Get the API server URL if available."""
         return None
 
@@ -106,26 +102,24 @@ def __init__(
 
     async def prepare(self) -> None:
         """Prepare the model wrapper."""
-        if await self.model.has_api_server.remote():
-            self.api_address = await self.model.get_api_server_url.remote()
-            if self.api_address is None:
-                raise RuntimeError(
-                    "Failed to connect to the API server. Please set `enable_openai_api` to `True`."
-                )
-            max_retries = 30
-            interval = 2  # seconds
-            for i in range(max_retries):
-                try:
-                    async with httpx.AsyncClient() as client:
-                        response = await client.get(self.api_address + "/health", timeout=5)
-                        if response.status_code == 200:
-                            return
-                except Exception as e:
-                    self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
-                await asyncio.sleep(interval)
-            raise RuntimeError(
-                f"API server at {self.api_address} not ready after {max_retries} attempts."
-            )
+        self.api_address = await self.model.get_api_server_url.remote()
+        if self.api_address is None:
+            self.logger.info("API server is not enabled for inference model.")
+            return
+        max_retries = 30
+        interval = 2  # seconds
+        for i in range(max_retries):
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.get(self.api_address + "/health", timeout=5)
+                    if response.status_code == 200:
+                        return
+            except Exception as e:
+                self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
+            await asyncio.sleep(interval)
+        raise RuntimeError(
+            f"API server at {self.api_address} not ready after {max_retries} attempts."
+        )
 
     def _record_history(self, exps: Union[Experience, List[Experience]]) -> None:
         """Record experiences to history."""
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index d5174a41cb..418fd270a6 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -103,6 +103,7 @@ def __init__(
         self.api_server_host = None
         self.api_server_port = None
         self.api_server = None
+        self.async_lock = asyncio.Lock()
 
     async def _initialize_tokenizer(self):
         if self.tokenizer is None:
@@ -455,38 +456,44 @@ async def init_process_group(
             ),
         )
 
-    async def run_api_server(self):
-        """Run the OpenAI API server in a Ray actor."""
-        if not self.config.enable_openai_api or not (
-            self.api_server_host is None or self.api_server_port is None
-        ):
-            return  # no need to run or already running
-
-        from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
-
-        self.api_server_host, self.api_server_port = self.get_available_address()
-        self.api_server = asyncio.create_task(
-            run_api_server_in_ray_actor(
-                self.async_llm,
-                self.api_server_host,
-                self.api_server_port,
-                self.config.model_path,
-                self.config.enable_auto_tool_choice,
-                self.config.tool_call_parser,
-                self.config.reasoning_parser,
-            )
-        )
+    async def run_api_server(self) -> bool:
+        """Run the OpenAI API server in a Ray actor.
+
+        Returns:
+            success (bool): Whether the API server is started successfully.
+        """
+        async with self.async_lock:
+            if not self.config.enable_openai_api:
+                return False  # Not enabled
+
+            if self.api_server_host is not None and self.api_server_port is not None:
+                return True  # already running
+
+            from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
+
+            api_server_host, api_server_port = self.get_available_address()
+            self.api_server = asyncio.create_task(
+                run_api_server_in_ray_actor(
+                    self.async_llm,
+                    api_server_host,
+                    api_server_port,
+                    self.config.model_path,
+                    self.config.enable_auto_tool_choice,
+                    self.config.tool_call_parser,
+                    self.config.reasoning_parser,
+                )
+            )
+            self.api_server_host = api_server_host
+            self.api_server_port = api_server_port
+            return True
 
-    def has_api_server(self) -> bool:
-        return self.config.enable_openai_api
-
-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         """Get the URL of the OpenAI API server.
 
         Returns:
             api_url (str): The URL of the OpenAI API server.
         """
-        if not self.has_api_server():
+        if not self.run_api_server():
             return None
         return f"http://{self.api_server_host}:{self.api_server_port}"
 
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index e7f79999d7..038c1dd5f9 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -154,12 +154,12 @@ async def prepare(self) -> None:
             await self.experience_pipeline.prepare.remote()
             self.logger.info("Experience pipeline is ready.")
             # make sure all rollout models are ready
-            run_api_ref = []
-            for model in self.models:
-                run_api_ref.append(model.run_api_server.remote())
-            for models in self.auxiliary_models:
-                for model in models:
-                    run_api_ref.append(model.run_api_server.remote())
+            run_api_ref = [model.run_api_server.remote() for model in self.models]
+            run_api_ref.extend(
+                model.run_api_server.remote()
+                for models in self.auxiliary_models
+                for model in models
+            )
             await asyncio.gather(*run_api_ref)
             self.logger.info("All models are ready.")
 

From 92b3784da9b28950064ba4291682ce23da88eb2e Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 29 Oct 2025 16:58:32 +0800
Subject: [PATCH 3/4] add auxiliary model tests
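
Exercise the auxiliary-model path end to end: a new explorer test runs two
rollout engines alongside two auxiliary engines (configured via
InferenceModelConfig with get_api_model_path()) on a RULER-style GSM8K
workflow, backed by a new "gsm8k_ruler" entry in tests/tools.py that points
math_ruler_workflow at the bundled GSM8K template data. The default
engine_type becomes "vllm", and get_api_server_url() now awaits
run_api_server(), which has been a coroutine since the previous commit.
With buffer.total_steps = 2, the rollout metrics are expected to stop at
step 2, hence the updated assertion.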
""" - if not self.has_api_server(): + if not self.run_api_server(): return None return f"http://{self.api_server_host}:{self.api_server_port}" diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index e7f79999d7..038c1dd5f9 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -154,12 +154,12 @@ async def prepare(self) -> None: await self.experience_pipeline.prepare.remote() self.logger.info("Experience pipeline is ready.") # make sure all rollout models are ready - run_api_ref = [] - for model in self.models: - run_api_ref.append(model.run_api_server.remote()) - for models in self.auxiliary_models: - for model in models: - run_api_ref.append(model.run_api_server.remote()) + run_api_ref = [model.run_api_server.remote() for model in self.models] + run_api_ref.extend( + model.run_api_server.remote() + for models in self.auxiliary_models + for model in models + ) await asyncio.gather(*run_api_ref) self.logger.info("All models are ready.") From 92b3784da9b28950064ba4291682ce23da88eb2e Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 29 Oct 2025 16:58:32 +0800 Subject: [PATCH 3/4] add auxiliary model tests --- tests/explorer/explorer_test.py | 25 +++++++++++++++++++++---- tests/tools.py | 11 +++++++++++ trinity/common/config.py | 2 +- trinity/common/models/vllm_model.py | 2 +- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index dedb5a4d35..91222e9884 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -15,6 +15,7 @@ RayUnittestBase, RayUnittestBaseAysnc, TensorBoardParser, + get_api_model_path, get_checkpoint_path, get_model_path, get_template_config, @@ -22,7 +23,7 @@ ) from trinity.buffer import get_buffer_reader from trinity.cli.launcher import explore, run_stage -from trinity.common.config import ExperienceBufferConfig +from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig from trinity.common.constants import StorageType from trinity.explorer.explorer import Explorer from trinity.manager.state_manager import StateManager @@ -31,6 +32,7 @@ class BaseExplorerCase(RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.buffer.total_epochs = 2 self.config.buffer.batch_size = 4 self.config.model.model_path = get_model_path() @@ -67,10 +69,25 @@ def test_explorer(self): self.assertTrue("eval/eval_long/accuracy/max" in eval_metrics) -class TestExplorerCountdownNoEval(BaseExplorerCase): +class TestExplorerGSM8KRULERNoEval(BaseExplorerCase): def test_explorer(self): - self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + self.config.explorer.rollout_model.engine_num = 2 + self.config.explorer.auxiliary_models = [ + InferenceModelConfig( + model_path=get_api_model_path(), + tensor_parallel_size=1, + engine_num=2, + ) + ] + self.config.algorithm.repeat_times = 2 + self.config.buffer.total_steps = 2 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k_ruler") self.config.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.algorithm.algorithm_type = "grpo" + self.config.algorithm.advantage_fn = "grpo" + self.config.algorithm.advantage_fn_args = { + "std_threshold": 0.0001, + } self.config.check_and_update() explore(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) @@ -78,7 +95,7 @@ def test_explorer(self): 
---
 trinity/common/config.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/trinity/common/config.py b/trinity/common/config.py
index 07f8d54e62..456ff136b5 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -471,11 +471,11 @@ class InferenceModelConfig:
     enable_openai_api: bool = False
 
     # For tool calls in OpenAI API
-    enable_auto_tool_choice: bool = True
+    enable_auto_tool_choice: bool = False
 
-    tool_call_parser: str = "hermes"
+    tool_call_parser: Optional[str] = None
 
-    reasoning_parser: str = "deepseek_r1"
+    reasoning_parser: Optional[str] = None
 
     # ! DO NOT SET
     bundle_indices: str = ""