diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index dedb5a4d35..91222e9884 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -15,6 +15,7 @@
     RayUnittestBase,
     RayUnittestBaseAysnc,
     TensorBoardParser,
+    get_api_model_path,
     get_checkpoint_path,
     get_model_path,
     get_template_config,
@@ -22,7 +23,7 @@
 )
 from trinity.buffer import get_buffer_reader
 from trinity.cli.launcher import explore, run_stage
-from trinity.common.config import ExperienceBufferConfig
+from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
 from trinity.common.constants import StorageType
 from trinity.explorer.explorer import Explorer
 from trinity.manager.state_manager import StateManager
@@ -31,6 +32,7 @@
 class BaseExplorerCase(RayUnittestBase):
     def setUp(self):
         self.config = get_template_config()
+        self.config.mode = "explore"
         self.config.buffer.total_epochs = 2
         self.config.buffer.batch_size = 4
         self.config.model.model_path = get_model_path()
@@ -67,10 +69,25 @@ def test_explorer(self):
         self.assertTrue("eval/eval_long/accuracy/max" in eval_metrics)


-class TestExplorerCountdownNoEval(BaseExplorerCase):
+class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
     def test_explorer(self):
-        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        self.config.explorer.rollout_model.engine_num = 2
+        self.config.explorer.auxiliary_models = [
+            InferenceModelConfig(
+                model_path=get_api_model_path(),
+                tensor_parallel_size=1,
+                engine_num=2,
+            )
+        ]
+        self.config.algorithm.repeat_times = 2
+        self.config.buffer.total_steps = 2
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k_ruler")
         self.config.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.algorithm.algorithm_type = "grpo"
+        self.config.algorithm.advantage_fn = "grpo"
+        self.config.algorithm.advantage_fn_args = {
+            "std_threshold": 0.0001,
+        }
         self.config.check_and_update()
         explore(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
@@ -78,7 +95,7 @@ def test_explorer(self):
         self.assertTrue(len(rollout_metrics) > 0)
         eval_metrics = parser.metric_list("eval")
         self.assertTrue(len(eval_metrics) == 0)
-        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)


 class TestExplorerGSM8k(BaseExplorerCase):
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index d53d492046..306cc8e447 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -134,10 +134,7 @@ def init_process_group(
     ) -> None:
         pass

-    def has_api_server(self) -> bool:
-        return False
-
-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         return None


@@ -161,10 +158,7 @@ def init_process_group(
     ) -> None:
         pass

-    def has_api_server(self) -> bool:
-        return True
-
-    def get_api_server_url(self) -> str:
+    async def get_api_server_url(self) -> str:
         return "http://localhost:12345"


diff --git a/tests/tools.py b/tests/tools.py
index 8d8cab98a5..62a5b4d92e 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -116,6 +116,17 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             default_workflow_type="math_workflow",
             default_reward_fn_type="math_reward",
         )
+    elif dataset_name == "gsm8k_ruler":
+        return TasksetConfig(
+            name=dataset_name,
+            path=os.path.join(os.path.dirname(__file__), "template", "data", "gsm8k"),
+            split="train",
+            format=FormatConfig(
+                prompt_key="question",
+                response_key="answer",
+            ),
+            default_workflow_type="math_ruler_workflow",
+        )
     elif dataset_name == "sft_for_gsm8k":
         # SFT dataset with 8 samples
         return ExperienceBufferConfig(
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 19889477e8..456ff136b5 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -436,7 +436,7 @@ class InferenceModelConfig:
     # ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
     model_path: str = ""

-    engine_type: str = "vllm_async"
+    engine_type: str = "vllm"
     engine_num: int = 1
     tensor_parallel_size: int = 1
     use_v1: bool = True
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index 9b12b48340..ffdc98c070 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -106,9 +106,7 @@ def create_inference_models(
                 config=config.explorer.rollout_model,
             )
         )
-    if config.explorer.rollout_model.enable_openai_api:
-        for engine in rollout_engines:
-            engine.run_api_server.remote()
+
     if config.explorer.rollout_model.enable_history:
         logger.info(
             "Model History recording is enabled. Please periodically extract "
@@ -138,10 +136,6 @@
             .remote(config=model_config)
         )
         auxiliary_engines.append(engines)

-    # all auxiliary engines run api server
-    for engines in auxiliary_engines:
-        for engine in engines:
-            engine.run_api_server.remote()
     return rollout_engines, auxiliary_engines

@@ -159,10 +153,10 @@ def create_debug_inference_model(config: Config) -> None:
     rollout_models, auxiliary_models = create_inference_models(config)
     # make sure models are started
     for m in rollout_models:
-        ray.get(m.get_model_path.remote())
+        ray.get(m.run_api_server.remote())
     for models in auxiliary_models:
         for m in models:
-            ray.get(m.get_model_path.remote())
+            ray.get(m.run_api_server.remote())
     logger.info(
         "----------------------------------------------------\n"
         "Inference models started successfully for debugging.\n"
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 2a709f2fa8..c22f6cbc99 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -51,11 +51,7 @@ def get_available_address(self) -> Tuple[str, int]:
             port = s.getsockname()[1]
         return address, port

-    def has_api_server(self) -> bool:
-        """Check if the model has an API server."""
-        return False
-
-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         """Get the API server URL if available."""
         return None

@@ -106,26 +102,24 @@ def __init__(

     async def prepare(self) -> None:
         """Prepare the model wrapper."""
-        if await self.model.has_api_server.remote():
-            self.api_address = await self.model.get_api_server_url.remote()
-            if self.api_address is None:
-                raise RuntimeError(
-                    "Failed to connect to the API server. Please set `enable_openai_api` to `True`."
-                )
-            max_retries = 30
-            interval = 2  # seconds
-            for i in range(max_retries):
-                try:
-                    async with httpx.AsyncClient() as client:
-                        response = await client.get(self.api_address + "/health", timeout=5)
-                        if response.status_code == 200:
-                            return
-                except Exception as e:
-                    self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
-                await asyncio.sleep(interval)
-            raise RuntimeError(
-                f"API server at {self.api_address} not ready after {max_retries} attempts."
-            )
+        self.api_address = await self.model.get_api_server_url.remote()
+        if self.api_address is None:
+            self.logger.info("API server is not enabled for inference model.")
+            return
+        max_retries = 30
+        interval = 2  # seconds
+        for i in range(max_retries):
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.get(self.api_address + "/health", timeout=5)
+                    if response.status_code == 200:
+                        return
+            except Exception as e:
+                self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
+            await asyncio.sleep(interval)
+        raise RuntimeError(
+            f"API server at {self.api_address} not ready after {max_retries} attempts."
+        )

     def _record_history(self, exps: Union[Experience, List[Experience]]) -> None:
         """Record experiences to history."""
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index ee8e30e293..9436bc7e06 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -103,6 +103,7 @@ def __init__(
         self.api_server_host = None
         self.api_server_port = None
         self.api_server = None
+        self.async_lock = asyncio.Lock()

     async def _initialize_tokenizer(self):
         if self.tokenizer is None:
@@ -455,35 +456,44 @@ async def init_process_group(
             ),
         )

-    async def run_api_server(self):
-        """Run the OpenAI API server in a Ray actor."""
-        if not (self.api_server_host is None or self.api_server_port is None):
-            raise RuntimeError("API server is already running.")
-        from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
-
-        self.api_server_host, self.api_server_port = self.get_available_address()
-        self.api_server = asyncio.create_task(
-            run_api_server_in_ray_actor(
-                self.async_llm,
-                self.api_server_host,
-                self.api_server_port,
-                self.config.model_path,
-                self.config.enable_auto_tool_choice,
-                self.config.tool_call_parser,
-                self.config.reasoning_parser,
-            )
-        )
+    async def run_api_server(self) -> bool:
+        """Run the OpenAI API server in a Ray actor.

-    def has_api_server(self) -> bool:
-        return self.config.enable_openai_api
+        Returns:
+            success (bool): Whether the API server is started successfully.
+        """
+        async with self.async_lock:
+            if not self.config.enable_openai_api:
+                return False  # Not enabled
+
+            if self.api_server_host is not None and self.api_server_port is not None:
+                return True  # already running
+
+            from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
+
+            api_server_host, api_server_port = self.get_available_address()
+            self.api_server = asyncio.create_task(
+                run_api_server_in_ray_actor(
+                    self.async_llm,
+                    api_server_host,
+                    api_server_port,
+                    self.config.model_path,
+                    self.config.enable_auto_tool_choice,
+                    self.config.tool_call_parser,
+                    self.config.reasoning_parser,
+                )
+            )
+            self.api_server_host = api_server_host
+            self.api_server_port = api_server_port
+            return True

-    def get_api_server_url(self) -> Optional[str]:
+    async def get_api_server_url(self) -> Optional[str]:
         """Get the URL of the OpenAI API server.

         Returns:
             api_url (str): The URL of the OpenAI API server.
         """
-        if not self.has_api_server():
+        if not await self.run_api_server():
             return None
         return f"http://{self.api_server_host}:{self.api_server_port}"
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index bf748e4877..038c1dd5f9 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -150,12 +150,18 @@ async def _nccl_weights_update(self):
     async def prepare(self) -> None:
         """Preparation before running."""
         try:
+            # prepare experience pipeline
             await self.experience_pipeline.prepare.remote()
             self.logger.info("Experience pipeline is ready.")
             # make sure all rollout models are ready
-            model_ready_ref = [model.__ray_ready__.remote() for model in self.models]
-            await asyncio.gather(*model_ready_ref)
-            self.logger.info("All rollout models are ready.")
+            run_api_ref = [model.run_api_server.remote() for model in self.models]
+            run_api_ref.extend(
+                model.run_api_server.remote()
+                for models in self.auxiliary_models
+                for model in models
+            )
+            await asyncio.gather(*run_api_ref)
+            self.logger.info("All models are ready.")
             if not self.use_nccl_sync:
                 master_address, master_port = await self.models[0].get_available_address.remote()