From 52adb78eb0412eb82b28b591d5a7cfbbb0af8ac3 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 2 Jul 2025 16:22:06 +0800 Subject: [PATCH 01/20] add new scheduler --- trinity/common/config.py | 2 +- trinity/explorer/scheduler.py | 261 ++++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 trinity/explorer/scheduler.py diff --git a/trinity/common/config.py b/trinity/common/config.py index 100d74d9da..a6d7eba036 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -306,7 +306,7 @@ class ExplorerConfig: runner_num: int = 1 max_timeout: int = 900 # wait each task for 15 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout - env_vars: dict = field(default_factory=dict) + runner_per_model: int = 8 # for inference models # for rollout model diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py new file mode 100644 index 0000000000..f248378f6c --- /dev/null +++ b/trinity/explorer/scheduler.py @@ -0,0 +1,261 @@ +"""Scheduler for rollout tasks.""" + +import asyncio +import time +from typing import List, Dict, Tuple, Optional +from collections import defaultdict, deque +import traceback + +import ray + +from trinity.common.models import InferenceModel +from trinity.common.config import Config +from trinity.common.workflows import Task +from trinity.explorer.workflow_runner import WorkflowRunner, Status +from trinity.utils.log import get_logger + + + +class RunnerWrapper: + + def __init__(self, runner: WorkflowRunner, runner_id: int): + self.logger = get_logger(__name__) + self.runner = runner + self.runner_id = runner_id + self.is_busy = False + self.current_task: Task = None + + async def run_with_retry(self, task: Task, retry_times: int) -> Tuple[Status, int]: + """ + Returns: + `Status`: The return status of the task. + `int`: The runner_id of current runner. 
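+                The runner_id lets the scheduler mark this runner as idle again after the task finishes.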
+ """ + last_exception_msg = None + self.is_busy = True + self.current_task = task + start_time = time.time() + try: + for attempt in range(retry_times + 1): + try: + status = await self.runner.run.remote(task) + if status.ok: + break + else: + self.logger.error(status.message) + except Exception: + last_exception_msg = traceback.format_exception() + self.logger.warning( + f"Task execution attempt {attempt + 1} failed:\n{last_exception_msg}" + ) + status = Status(ok=False, metric=dict(), message=last_exception_msg) + finally: + end_time = time.time() + status.metric["task_run_time"] = end_time - start_time + self.is_busy = False + self.current_task = None + return status, self.runner_id + + +class Scheduler: + """Scheduler for rollout tasks.""" + + def __init__( + self, + config: Config, + rollout_model: List[InferenceModel], + auxiliary_models: Optional[List[List[InferenceModel]]] = None, + ): + self.logger = get_logger(__name__) + self.config = config + self.rollout_model = rollout_model + self.auxiliary_models = auxiliary_models or [] + self.namespace = ray.get_runtime_context().namespace + self.timeout = config.explorer.max_timeout + self.max_retry_times = config.explorer.max_retry_times + self.running = False + + self.runner_num = len(rollout_model) * config.explorer.runner_per_model + self.runners: Dict[int, RunnerWrapper] = dict() + self.idle_runners = set() + self.busy_runners = dict() + + self.pending_tasks: Dict[int, deque] = defaultdict(deque) # step -> tasks + self.running_tasks: Dict[int, set[asyncio.Future]] = defaultdict(set) # step -> futures + self.completed_tasks: Dict[int, deque[Status]] = defaultdict(deque) # step -> results + + self.scheduler_task: Optional[asyncio.Task] = None + self.running = False + + self.total_scheduled = 0 + self.total_completed = 0 + for i in range(self.runner_num): + self._create_runner(i) + + async def _create_runner( + self, + runner_id: int, + ) -> None: + runner = RunnerWrapper( + runner=( + ray.remote(WorkflowRunner) + .options( + namespace=self.namespace, + scheduling_strategy="SPREAD", + ) + .remote( + self.config, + self.rollout_model[runner_id % len(self.rollout_model)], + [ + self.auxiliary_models[j][runner_id % len(self.auxiliary_models[j])] + for j in range(len(self.auxiliary_models)) + ], + ) + ), + runner_id=runner_id, + ) + self.runners[runner_id] = runner + self.idle_runners.add(runner_id) + + def _restart_runner(self, runner_id: int): + """Restart a runner.""" + try: + ray.kill(self.runners[runner_id]) + except: + pass + + self.create_runner(runner_id) + + + async def _scheduler_loop(self) -> None: + self.logger.info("Scheduler loop started.") + while self.running: + try: + await self._schedule_pending_tasks() + await self._check_completed_tasks() + await asyncio.sleep(0.01) + except Exception: + self.logger.error(f"Error in scheduler loop:\n{traceback.format_exc()}") + await asyncio.sleep(0.1) + self.logger.info("Scheduler loop stopped.") + + async def _schedule_pending_tasks(self) -> None: + if not self.idle_runners: + return + + for step in sorted(self.pending_tasks.keys()): + task_queue = self.pending_tasks[step] + + while task_queue and self.idle_runners: + task = task_queue.pop() + runner_id = self.idle_runners.pop() + self.busy_runners[runner_id] = (task, step) + self.running_tasks[step].add( + asyncio.create_task(self.runners[runner_id].run_with_retry(task)) + ) + + if not task_queue: + del self.pending_tasks[step] + + async def _check_completed_tasks(self) -> None: + for step in list(self.running_tasks.keys()): + 
futures = self.running_tasks[step] + + for future in list(futures): + if future.done(): + futures.remove(future) + try: + task_result, runner_id = await future + self.completed_tasks[step].appendleft(task_result) + self.busy_runners.pop(runner_id) + self.idle_runners.add(runner_id) + + self.logger.debug( + f"Task completed (step {step}), success: {task_result.success}" + ) + + except Exception as e: + self.logger.error(f"Error getting task result: {e}") + + if not futures: + del self.running_tasks[step] + + async def start(self) -> None: + if self.running: + return + self.running = True + await asyncio.gather([self._create_runner(i) for i in range(self.runner_num)]) + self.scheduler_task = asyncio.create_task(self._scheduler_loop()) + + async def stop(self) -> None: + if not self.running: + return + + self.running = False + all_running_futures = [] + for futures in self.running_tasks.values(): + all_running_futures.extend(futures) + + if all_running_futures: + self.logger.info(f"Waiting for {len(all_running_futures)} running tasks to complete...") + await asyncio.gather(*all_running_futures, return_exceptions=True) + + if self.scheduler_task: + self.scheduler_task.cancel() + try: + await self.scheduler_task + except asyncio.CancelledError: + pass + + self.logger.info("Scheduler stopped") + + def schedule(self, tasks: List[Task], step: int) -> None: + """Schedule the provided tasks. + + Args: + tasks (`List[Task]`): The tasks to schedule. + step (`int`): The step number of provided tasks. + """ + if not tasks: + return + for task in tasks: + self.pending_tasks[step].appendleft(task) + + + async def get_results( + self, step: int, min_num: Optional[int] = None, timeout: Optional[float] = None + ) -> List[Dict]: + """Get the result of tasks at the specific step. + + Args: + step (`int`): Only wait for tasks at this step. + min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `step`. + timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. 
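+
+        Returns:
+            The results of tasks completed at `step`, at most `min_num` entries; fewer may be returned if the timeout is reached.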
+ """ + timeout = timeout or self.timeout + start_time = time.time() + if min_num is None: + min_num = len(self.pending_tasks[step]) + len(self.running_tasks[step]) + len(self.completed_tasks[step]) + self.logger.debug(f"Waiting for {min_num} tasks to complete...") + + while time.time() - start_time < timeout: + completed_count = len(self.completed_tasks[step]) + if completed_count >= min_num: + break + await asyncio.sleep(0.1) + + results = [] + for _ in range(min_num): + if len(self.completed_tasks[step]) > 0: + results.append(self.completed_tasks[step].pop()) + + if not self.completed_tasks[step]: + del self.completed_tasks[step] + + completed_count = len(results) + if completed_count < min_num: + self.logger.warning( + f"Timeout reached, only {completed_count}/{min_num} tasks completed" + ) + + return results From ed90e01f8d783bb78343d1f6b0f3fc155f116377 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 2 Jul 2025 18:25:37 +0800 Subject: [PATCH 02/20] smoke test scheduler --- tests/explorer/scheduler_test.py | 165 ++++++++++++++++++++++++++++++ trinity/explorer/scheduler.py | 170 ++++++++++++++++++++++--------- 2 files changed, 287 insertions(+), 48 deletions(-) create mode 100644 tests/explorer/scheduler_test.py diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py new file mode 100644 index 0000000000..e3a8d06621 --- /dev/null +++ b/tests/explorer/scheduler_test.py @@ -0,0 +1,165 @@ +import time +import unittest +from typing import List, Tuple + +import ray +import torch + +from tests.tools import get_template_config +from trinity.buffer.reader.queue_reader import QueueReader +from trinity.common.config import StorageConfig +from trinity.common.constants import StorageType +from trinity.common.experience import Experience +from trinity.common.models.model import InferenceModel +from trinity.common.workflows import Task +from trinity.common.workflows.workflow import WORKFLOWS, Workflow +from trinity.explorer.scheduler import Scheduler + + +@WORKFLOWS.register_module("dummy_workflow") +class DummyWorkflow(Workflow): + def __init__(self, model, task, auxiliary_models): + super().__init__(model, task, auxiliary_models) + self.error_type = task.raw_task.get("error_type", "") + self.seconds = None + if "timeout" in self.error_type: + self.seconds = int(self.error_type.split("_")[-1]) + + def run(self) -> List[Experience]: + if "timeout" in self.error_type: + time.sleep(self.seconds) + elif self.error_type == "exception": + raise ValueError("Exception occurred") + elif self.error_type == "exit": + exit(1) + elif self.error_type == "auxiliary_models": + assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 + return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)] + + +@ray.remote +class DummyModel(InferenceModel): + def sync_model(self, model_version, update_weight_args_list): + return True + + def get_model_version(self): + return 0 + + def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + timeout: int = 1200, + update_with_checkpoint: bool = True, + ) -> None: + pass + + +@ray.remote +class DummyAuxiliaryModel(InferenceModel): + def sync_model(self, model_version, update_weight_args_list): + return True + + def get_model_version(self): + return 0 + + def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + 
timeout: int = 1200, + update_with_checkpoint: bool = True, + ) -> None: + pass + + def has_api_server(self) -> bool: + return True + + def api_server_ready(self) -> Tuple[str, str]: + return "http://localhosts:12345", "placeholder" + + +def generate_tasks(total_num: int, timeout_num: int = 0, exception_num: int = 0): + tasks = [Task(workflow=DummyWorkflow, raw_task={}) for _ in range(total_num)] + tasks.extend( + [ + Task( + workflow=DummyWorkflow, + raw_task={"error_type": "timeout", "timeout": 5}, + ) + for _ in range(timeout_num) + ] + ) + tasks.extend( + [ + Task( + workflow=DummyWorkflow, + raw_task={"error_type": "exception"}, + ) + for _ in range(exception_num) + ] + ) + return tasks + + +class SchedulerTest(unittest.IsolatedAsyncioTestCase): + def setUp(self): + ray.init(ignore_reinit_error=True) + self.config = get_template_config() + self.config.explorer.max_retry_times = 1 + self.config.explorer.max_timeout = 5 + self.config.explorer.runner_per_model = 2 + self.config.buffer.read_batch_size = 2 + self.config.buffer.pad_token_id = 0 + self.config.buffer.explorer_output = ( + self.config.buffer.trainer_input.experience_buffer + ) = StorageConfig( + name="test", + storage_type=StorageType.QUEUE, + algorithm_type="ppo", + path="", + ) + self.queue = QueueReader( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + + async def test_scheduler(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + tasks = generate_tasks(8) + scheduler.schedule(tasks, step=0) + self.assertTrue(scheduler.has_step(0)) + results = await scheduler.get_results(step=0, min_num=8, timeout=20) + self.assertEqual(len(results), 8) + scheduler.schedule(tasks, step=1) + scheduler.schedule(tasks[:4], step=2) + self.assertFalse(scheduler.has_step(0)) + results = await scheduler.get_results(step=0, min_num=8) + self.assertFalse(scheduler.has_step(0)) + self.assertEqual(len(results), 0) # step 0 has no more tasks + self.assertFalse(scheduler.has_step(0)) + self.assertTrue(scheduler.has_step(1)) + self.assertTrue(scheduler.has_step(2)) + await scheduler.wait_all() + st = time.time() + results = await scheduler.get_results(step=1) + et = time.time() + self.assertTrue(et - st < 1.0) + self.assertEqual(len(results), 8) + self.assertFalse(scheduler.has_step(1)) + self.assertTrue(scheduler.has_step(2)) + st = time.time() + results = await scheduler.get_results(step=2) + et = time.time() + self.assertTrue(et - st < 1.0) + self.assertEqual(len(results), 4) + self.assertFalse(scheduler.has_step(2)) + await scheduler.stop() diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index f248378f6c..9195768590 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -2,49 +2,74 @@ import asyncio import time -from typing import List, Dict, Tuple, Optional -from collections import defaultdict, deque import traceback +from collections import defaultdict, deque +from typing import Dict, List, Optional, Tuple import ray -from trinity.common.models import InferenceModel from trinity.common.config import Config +from trinity.common.models import InferenceModel from trinity.common.workflows import Task -from trinity.explorer.workflow_runner import WorkflowRunner, Status +from trinity.explorer.workflow_runner import Status, WorkflowRunner from trinity.utils.log import get_logger - class RunnerWrapper: + """A wrapper for a WorkflowRunner""" - def __init__(self, runner: WorkflowRunner, runner_id: int): + def 
__init__( + self, + runner_id: int, + rollout_model: InferenceModel, + auxiliary_models: List[InferenceModel], + config: Config, + ): self.logger = get_logger(__name__) - self.runner = runner self.runner_id = runner_id - self.is_busy = False - self.current_task: Task = None + self.rollout_model = rollout_model + self.auxiliary_models = auxiliary_models + self.config = config + self.retry_times = config.explorer.max_retry_times + self.timeout = config.explorer.max_timeout + self.namespace = ray.get_runtime_context().namespace + self.runner = self._create_runner() + + def _create_runner(self): + return ( + ray.remote(WorkflowRunner) + .options( + namespace=self.namespace, + scheduling_strategy="SPREAD", + ) + .remote(self.config, self.rollout_model, self.auxiliary_models) + ) - async def run_with_retry(self, task: Task, retry_times: int) -> Tuple[Status, int]: + async def run_with_retry(self, task: Task) -> Tuple[Status, int]: """ Returns: `Status`: The return status of the task. `int`: The runner_id of current runner. """ last_exception_msg = None - self.is_busy = True - self.current_task = task start_time = time.time() + status = Status(ok=False, metric=dict()) try: - for attempt in range(retry_times + 1): + for attempt in range(self.retry_times + 1): try: - status = await self.runner.run.remote(task) + status = await asyncio.wait_for(self.runner.run_task.remote(task), self.timeout) if status.ok: break else: self.logger.error(status.message) + except asyncio.TimeoutError: + self.logger.error(f"Timeout when running task: {task}") + self.restart_runner() + status = Status( + ok=False, metric=dict(), message=f"Timeout when running task: {task}" + ) except Exception: - last_exception_msg = traceback.format_exception() + last_exception_msg = traceback.format_exc() self.logger.warning( f"Task execution attempt {attempt + 1} failed:\n{last_exception_msg}" ) @@ -52,10 +77,15 @@ async def run_with_retry(self, task: Task, retry_times: int) -> Tuple[Status, in finally: end_time = time.time() status.metric["task_run_time"] = end_time - start_time - self.is_busy = False - self.current_task = None return status, self.runner_id + def restart_runner(self): + try: + ray.kill(self.runner) + except Exception: + pass + self.runner = self._create_runner() + class Scheduler: """Scheduler for rollout tasks.""" @@ -89,43 +119,33 @@ def __init__( self.total_scheduled = 0 self.total_completed = 0 - for i in range(self.runner_num): - self._create_runner(i) - async def _create_runner( + def _create_runner( self, runner_id: int, - ) -> None: + ): runner = RunnerWrapper( - runner=( - ray.remote(WorkflowRunner) - .options( - namespace=self.namespace, - scheduling_strategy="SPREAD", - ) - .remote( - self.config, - self.rollout_model[runner_id % len(self.rollout_model)], - [ - self.auxiliary_models[j][runner_id % len(self.auxiliary_models[j])] - for j in range(len(self.auxiliary_models)) - ], - ) - ), runner_id=runner_id, + rollout_model=self.rollout_model[runner_id % len(self.rollout_model)], + auxiliary_models=[ + self.auxiliary_models[j][runner_id % len(self.auxiliary_models[j])] + for j in range(len(self.auxiliary_models)) + ], + config=self.config, ) self.runners[runner_id] = runner self.idle_runners.add(runner_id) def _restart_runner(self, runner_id: int): """Restart a runner.""" - try: - ray.kill(self.runners[runner_id]) - except: - pass - - self.create_runner(runner_id) + self.runners[runner_id].restart_runner() + + if runner_id in self.busy_runners: + task, idx = self.busy_runners.pop(runner_id) + 
self.logger.warning(f"Runner failed to run task at step {idx}: {task.raw_task}") + self.idle_runners.add(runner_id) + self.logger.info(f"Runner {runner_id} restarted.") async def _scheduler_loop(self) -> None: self.logger.info("Scheduler loop started.") @@ -171,7 +191,7 @@ async def _check_completed_tasks(self) -> None: self.idle_runners.add(runner_id) self.logger.debug( - f"Task completed (step {step}), success: {task_result.success}" + f"Task completed (step {step}), success: {task_result.ok}" ) except Exception as e: @@ -184,8 +204,12 @@ async def start(self) -> None: if self.running: return self.running = True - await asyncio.gather([self._create_runner(i) for i in range(self.runner_num)]) + for i in range(self.runner_num): + self._create_runner(i) self.scheduler_task = asyncio.create_task(self._scheduler_loop()) + for _, runner in self.runners.items(): + await runner.runner.__ray_ready__.remote() + self.logger.info(f"Starting Scheduler with {self.runner_num} runners") async def stop(self) -> None: if not self.running: @@ -206,7 +230,6 @@ async def stop(self) -> None: await self.scheduler_task except asyncio.CancelledError: pass - self.logger.info("Scheduler stopped") def schedule(self, tasks: List[Task], step: int) -> None: @@ -221,10 +244,9 @@ def schedule(self, tasks: List[Task], step: int) -> None: for task in tasks: self.pending_tasks[step].appendleft(task) - async def get_results( self, step: int, min_num: Optional[int] = None, timeout: Optional[float] = None - ) -> List[Dict]: + ) -> List[Status]: """Get the result of tasks at the specific step. Args: @@ -235,7 +257,14 @@ async def get_results( timeout = timeout or self.timeout start_time = time.time() if min_num is None: - min_num = len(self.pending_tasks[step]) + len(self.running_tasks[step]) + len(self.completed_tasks[step]) + min_num = 0 + if step in self.pending_tasks: + min_num += len(self.pending_tasks[step]) + if step in self.running_tasks: + min_num += len(self.running_tasks[step]) + if step in self.completed_tasks: + min_num += len(self.completed_tasks[step]) + self.logger.debug(f"Waiting for {min_num} tasks to complete...") while time.time() - start_time < timeout: @@ -244,6 +273,12 @@ async def get_results( break await asyncio.sleep(0.1) + if time.time() - start_time > timeout: + self.logger.error(f"Timed out waiting for tasks to complete after {timeout} seconds") + busy_runner_ids = list(self.busy_runners.keys()) + for runner_id in busy_runner_ids: + self._restart_runner(runner_id) + results = [] for _ in range(min_num): if len(self.completed_tasks[step]) > 0: @@ -259,3 +294,42 @@ async def get_results( ) return results + + def has_step(self, step: int) -> bool: + return ( + step in self.completed_tasks or step in self.pending_tasks or step in self.running_tasks + ) + + async def wait_all(self, timeout: Optional[float] = None) -> None: + """Wait for all tasks to complete without poping results. 
If timeout reached, raise TimeoutError.""" + timeout = timeout or self.timeout + start_time = time.time() + + self.logger.debug("Waiting for all tasks to complete...") + + while time.time() - start_time < timeout: + has_pending = bool(self.pending_tasks) + has_running = bool(self.running_tasks) + + if not has_pending and not has_running: + self.logger.debug("All tasks completed successfully") + return + + pending_count = sum(len(tasks) for tasks in self.pending_tasks.values()) + running_count = sum(len(futures) for futures in self.running_tasks.values()) + + self.logger.debug(f"Pending tasks: {pending_count}, Running tasks: {running_count}") + + await asyncio.sleep(0.1) + + pending_count = sum(len(tasks) for tasks in self.pending_tasks.values()) + running_count = sum(len(futures) for futures in self.running_tasks.values()) + + error_msg = f"Timeout after {timeout} seconds. Still have {pending_count} pending tasks and {running_count} running tasks" + self.logger.error(error_msg) + + busy_runner_ids = list(self.busy_runners.keys()) + for runner_id in busy_runner_ids: + self._restart_runner(runner_id) + + raise TimeoutError(error_msg) From 1aabed7b60a0986ac1b24303e29d3f6287456eff Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 2 Jul 2025 20:43:28 +0800 Subject: [PATCH 03/20] add more tests to scheduler --- tests/explorer/scheduler_test.py | 188 ++++++++++++++++++++++++++++++- trinity/explorer/scheduler.py | 43 ++++--- 2 files changed, 213 insertions(+), 18 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index e3a8d06621..fa7fd99d85 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1,3 +1,4 @@ +import asyncio import time import unittest from typing import List, Tuple @@ -23,7 +24,12 @@ def __init__(self, model, task, auxiliary_models): self.error_type = task.raw_task.get("error_type", "") self.seconds = None if "timeout" in self.error_type: - self.seconds = int(self.error_type.split("_")[-1]) + # 提取超时时间,格式如 "timeout_5" + parts = self.error_type.split("_") + if len(parts) > 1: + self.seconds = int(parts[-1]) + else: + self.seconds = 10 # 默认超时时间 def run(self) -> List[Experience]: if "timeout" in self.error_type: @@ -34,7 +40,13 @@ def run(self) -> List[Experience]: exit(1) elif self.error_type == "auxiliary_models": assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 - return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)] + + # 返回一个成功的结果 + return [ + Experience( + tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success" + ) + ] @ray.remote @@ -87,17 +99,29 @@ def api_server_ready(self) -> Tuple[str, str]: return "http://localhosts:12345", "placeholder" -def generate_tasks(total_num: int, timeout_num: int = 0, exception_num: int = 0): +def generate_tasks( + total_num: int, timeout_num: int = 0, exception_num: int = 0, timeout_seconds: int = 10 +): + """Generate some tasks for testing + + Args: + total_num: number of normal tasks + timeout_num: number of timeout tasks + exception_num: number of exception tasks + timeout_seconds: the timeout for timeout tasks + """ tasks = [Task(workflow=DummyWorkflow, raw_task={}) for _ in range(total_num)] + tasks.extend( [ Task( workflow=DummyWorkflow, - raw_task={"error_type": "timeout", "timeout": 5}, + raw_task={"error_type": f"timeout_{timeout_seconds}"}, ) for _ in range(timeout_num) ] ) + tasks.extend( [ Task( @@ -107,6 +131,7 @@ def generate_tasks(total_num: int, timeout_num: int = 0, 
exception_num: int = 0) for _ in range(exception_num) ] ) + return tasks @@ -131,7 +156,154 @@ def setUp(self): self.config.buffer.trainer_input.experience_buffer, self.config.buffer ) - async def test_scheduler(self): + async def test_get_results(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + # tasks = generate_tasks(8) + # scheduler.schedule(tasks, step=0) + + # results = await scheduler.get_results(step=0, min_num=8, timeout=20) + # self.assertEqual(len(results), 8) + + # for result in results: + # self.assertTrue(result.ok) + + # for step in range(1, 4): + # tasks = generate_tasks(4) + # scheduler.schedule(tasks, step=step) + + # for step in range(1, 4): + # self.assertTrue(scheduler.has_step(step)) + # results = await scheduler.get_results(step=step, min_num=4, timeout=10) + # self.assertEqual(len(results), 4) + # self.assertFalse(scheduler.has_step(step)) + + # tasks = generate_tasks(3) + # scheduler.schedule(tasks, step=4) + # self.assertTrue(scheduler.has_step(4)) + # results = await scheduler.get_results(step=4) + # self.assertEqual(len(results), 3) + # self.assertFalse(scheduler.has_step(4)) + + # test timeout + tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) + scheduler.schedule(tasks, step=0) + + start_time = time.time() + results = await scheduler.get_results(step=0, min_num=4, timeout=3) + end_time = time.time() + + self.assertLessEqual(end_time - start_time, 5) + self.assertEqual(len(results), 2) + + # test run tasks after timeout + tasks = generate_tasks(4) + scheduler.schedule(tasks, step=0) + + # actor restart is slow, set a big timeout + results = await scheduler.get_results(step=0, timeout=20) + self.assertEqual(len(results), 4) + + success_count = sum(1 for r in results if r.ok) + + self.assertEqual(success_count, sum(1 for r in results if r.ok)) + + # test exception tasks + tasks = generate_tasks(1, exception_num=3) + scheduler.schedule(tasks, step=1) + results = await scheduler.get_results(step=1, timeout=5) + self.assertEqual(len(results), 4) + + success_count = sum(1 for r in results if r.ok) + self.assertEqual(success_count, 1) + + await scheduler.stop() + + async def test_wait_all(self): + """Test wait all""" + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks1 = generate_tasks(4) + tasks2 = generate_tasks(3) + scheduler.schedule(tasks1, step=0) + scheduler.schedule(tasks2, step=1) + + start_time = time.time() + await scheduler.wait_all(timeout=10.0) + end_time = time.time() + + self.assertLess(end_time - start_time, 5.0) + + self.assertEqual(len(scheduler.pending_tasks), 0) + self.assertEqual(len(scheduler.running_tasks), 0) + + results0 = await scheduler.get_results(step=0, min_num=4, timeout=1) + results1 = await scheduler.get_results(step=1, min_num=3, timeout=1) + self.assertEqual(len(results0), 4) + self.assertEqual(len(results1), 3) + + # test timeout + tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) + scheduler.schedule(tasks, step=0) + + start_time = time.time() + with self.assertRaises(TimeoutError): + await scheduler.wait_all(timeout=3.0) + end_time = time.time() + + self.assertGreaterEqual(end_time - start_time, 2.8) + self.assertLessEqual(end_time - start_time, 4.0) + + # test empty scenario + + start_time = time.time() + await scheduler.wait_all(timeout=5.0) + end_time = time.time() + + self.assertLess(end_time - start_time, 1.0) + await scheduler.stop() + + async def 
test_concurrent_operations(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + async def schedule_tasks(step, num_tasks): + tasks = generate_tasks(num_tasks) + scheduler.schedule(tasks, step=step) + return await scheduler.get_results(step=step, min_num=num_tasks, timeout=10) + + results = await asyncio.gather( + schedule_tasks(0, 3), + schedule_tasks(1, 4), + schedule_tasks(2, 2), + ) + + self.assertEqual(len(results[0]), 3) + self.assertEqual(len(results[1]), 4) + self.assertEqual(len(results[2]), 2) + + await scheduler.stop() + + async def test_scheduler_restart_after_stop(self): + scheduler = Scheduler(self.config, [DummyModel.remote()]) + + await scheduler.start() + tasks = generate_tasks(2) + scheduler.schedule(tasks, step=0) + results = await scheduler.get_results(step=0, min_num=2, timeout=10) + self.assertEqual(len(results), 2) + await scheduler.stop() + + await scheduler.start() + tasks = generate_tasks(3) + scheduler.schedule(tasks, step=1) + results = await scheduler.get_results(step=1, min_num=3, timeout=10) + self.assertEqual(len(results), 3) + await scheduler.stop() + + async def test_scheduler_all_methods(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() tasks = generate_tasks(8) @@ -163,3 +335,9 @@ async def test_scheduler(self): self.assertEqual(len(results), 4) self.assertFalse(scheduler.has_step(2)) await scheduler.stop() + + def tearDown(self): + try: + ray.shutdown() + except Exception: + pass diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 9195768590..119674bef1 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -52,6 +52,7 @@ async def run_with_retry(self, task: Task) -> Tuple[Status, int]: `int`: The runner_id of current runner. 
""" last_exception_msg = None + await self.runner.__ray_ready__.remote() start_time = time.time() status = Status(ok=False, metric=dict()) try: @@ -63,11 +64,11 @@ async def run_with_retry(self, task: Task) -> Tuple[Status, int]: else: self.logger.error(status.message) except asyncio.TimeoutError: - self.logger.error(f"Timeout when running task: {task}") - self.restart_runner() - status = Status( - ok=False, metric=dict(), message=f"Timeout when running task: {task}" + last_exception_msg = ( + f"Timeout when running task at runner {self.runner_id}: {task}" ) + self.logger.error(last_exception_msg) + status = Status(ok=False, metric=dict(), message=last_exception_msg) except Exception: last_exception_msg = traceback.format_exc() self.logger.warning( @@ -80,11 +81,12 @@ async def run_with_retry(self, task: Task) -> Tuple[Status, int]: return status, self.runner_id def restart_runner(self): + old_runner = self.runner + self.runner = self._create_runner() try: - ray.kill(self.runner) + ray.kill(old_runner) except Exception: pass - self.runner = self._create_runner() class Scheduler: @@ -107,8 +109,8 @@ def __init__( self.runner_num = len(rollout_model) * config.explorer.runner_per_model self.runners: Dict[int, RunnerWrapper] = dict() - self.idle_runners = set() - self.busy_runners = dict() + self.idle_runners = set() # runner_id + self.busy_runners = dict() # runner_id -> (task, step) self.pending_tasks: Dict[int, deque] = defaultdict(deque) # step -> tasks self.running_tasks: Dict[int, set[asyncio.Future]] = defaultdict(set) # step -> futures @@ -142,7 +144,9 @@ def _restart_runner(self, runner_id: int): if runner_id in self.busy_runners: task, idx = self.busy_runners.pop(runner_id) - self.logger.warning(f"Runner failed to run task at step {idx}: {task.raw_task}") + self.logger.warning( + f"Runner {runner_id} failed to run task at step {idx}: {task.raw_task}" + ) self.idle_runners.add(runner_id) self.logger.info(f"Runner {runner_id} restarted.") @@ -200,6 +204,16 @@ async def _check_completed_tasks(self) -> None: if not futures: del self.running_tasks[step] + def _clear_timeout_tasks(self, step: int) -> None: + if step in self.pending_tasks: + self.logger.info(f"Clear timeout pending tasks at step {step}.") + del self.pending_tasks[step] + if step in self.running_tasks: + self.logger.info(f"Clear timeout running tasks at step {step}.") + for future in self.running_tasks[step]: + future.cancel() + del self.running_tasks[step] + async def start(self) -> None: if self.running: return @@ -275,9 +289,10 @@ async def get_results( if time.time() - start_time > timeout: self.logger.error(f"Timed out waiting for tasks to complete after {timeout} seconds") - busy_runner_ids = list(self.busy_runners.keys()) - for runner_id in busy_runner_ids: - self._restart_runner(runner_id) + self._clear_timeout_tasks(step=step) + for runner_id in list(self.busy_runners.keys()): + if self.busy_runners[runner_id][1] == step: + self._restart_runner(runner_id) results = [] for _ in range(min_num): @@ -324,8 +339,10 @@ async def wait_all(self, timeout: Optional[float] = None) -> None: pending_count = sum(len(tasks) for tasks in self.pending_tasks.values()) running_count = sum(len(futures) for futures in self.running_tasks.values()) + for step in self.pending_tasks.keys() | self.running_tasks.keys(): + self._clear_timeout_tasks(step) - error_msg = f"Timeout after {timeout} seconds. Still have {pending_count} pending tasks and {running_count} running tasks" + error_msg = f"Timeout after {timeout} seconds. 
Still have {pending_count} pending tasks and {running_count} running tasks." self.logger.error(error_msg) busy_runner_ids = list(self.busy_runners.keys()) From d1149bae68ac42f95b125fa66dccfbd18ff10f74 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 3 Jul 2025 15:40:28 +0800 Subject: [PATCH 04/20] refactor explorer with the new scheduler --- tests/explorer/scheduler_test.py | 104 ++++++++------- trinity/explorer/explorer.py | 217 +++++++++++++++---------------- trinity/explorer/scheduler.py | 134 ++++++++++--------- trinity/utils/monitor.py | 34 ++++- 4 files changed, 269 insertions(+), 220 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index fa7fd99d85..e9323b50ed 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -24,12 +24,11 @@ def __init__(self, model, task, auxiliary_models): self.error_type = task.raw_task.get("error_type", "") self.seconds = None if "timeout" in self.error_type: - # 提取超时时间,格式如 "timeout_5" parts = self.error_type.split("_") if len(parts) > 1: self.seconds = int(parts[-1]) else: - self.seconds = 10 # 默认超时时间 + self.seconds = 10 def run(self) -> List[Experience]: if "timeout" in self.error_type: @@ -41,7 +40,6 @@ def run(self) -> List[Experience]: elif self.error_type == "auxiliary_models": assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 - # 返回一个成功的结果 return [ Experience( tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success" @@ -160,38 +158,38 @@ async def test_get_results(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() - # tasks = generate_tasks(8) - # scheduler.schedule(tasks, step=0) + tasks = generate_tasks(8) + scheduler.schedule(tasks, batch_id=0) - # results = await scheduler.get_results(step=0, min_num=8, timeout=20) - # self.assertEqual(len(results), 8) + results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + self.assertEqual(len(results), 8) - # for result in results: - # self.assertTrue(result.ok) + for result in results: + self.assertTrue(result.ok) - # for step in range(1, 4): - # tasks = generate_tasks(4) - # scheduler.schedule(tasks, step=step) + for batch_id in range(1, 4): + tasks = generate_tasks(4) + scheduler.schedule(tasks, batch_id=batch_id) - # for step in range(1, 4): - # self.assertTrue(scheduler.has_step(step)) - # results = await scheduler.get_results(step=step, min_num=4, timeout=10) - # self.assertEqual(len(results), 4) - # self.assertFalse(scheduler.has_step(step)) + for batch_id in range(1, 4): + self.assertTrue(scheduler.has_step(batch_id)) + results = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) + self.assertEqual(len(results), 4) + self.assertFalse(scheduler.has_step(batch_id)) - # tasks = generate_tasks(3) - # scheduler.schedule(tasks, step=4) - # self.assertTrue(scheduler.has_step(4)) - # results = await scheduler.get_results(step=4) - # self.assertEqual(len(results), 3) - # self.assertFalse(scheduler.has_step(4)) + tasks = generate_tasks(3) + scheduler.schedule(tasks, batch_id=4) + self.assertTrue(scheduler.has_step(4)) + results = await scheduler.get_results(batch_id=4) + self.assertEqual(len(results), 3) + self.assertFalse(scheduler.has_step(4)) # test timeout tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) - scheduler.schedule(tasks, step=0) + scheduler.schedule(tasks, batch_id=0) start_time = time.time() - results = await scheduler.get_results(step=0, min_num=4, timeout=3) + 
results = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) end_time = time.time() self.assertLessEqual(end_time - start_time, 5) @@ -199,10 +197,10 @@ async def test_get_results(self): # test run tasks after timeout tasks = generate_tasks(4) - scheduler.schedule(tasks, step=0) + scheduler.schedule(tasks, batch_id=0) # actor restart is slow, set a big timeout - results = await scheduler.get_results(step=0, timeout=20) + results = await scheduler.get_results(batch_id=0, timeout=20) self.assertEqual(len(results), 4) success_count = sum(1 for r in results if r.ok) @@ -211,13 +209,21 @@ async def test_get_results(self): # test exception tasks tasks = generate_tasks(1, exception_num=3) - scheduler.schedule(tasks, step=1) - results = await scheduler.get_results(step=1, timeout=5) + scheduler.schedule(tasks, batch_id=1) + results = await scheduler.get_results(batch_id=1, timeout=5) self.assertEqual(len(results), 4) success_count = sum(1 for r in results if r.ok) self.assertEqual(success_count, 1) + # test clear_timeout_tasks + tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) + scheduler.schedule(tasks, batch_id=2) + results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) + self.assertEqual(len(results), 3) + results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) + self.assertEqual(len(results), 1) + await scheduler.stop() async def test_wait_all(self): @@ -227,8 +233,8 @@ async def test_wait_all(self): tasks1 = generate_tasks(4) tasks2 = generate_tasks(3) - scheduler.schedule(tasks1, step=0) - scheduler.schedule(tasks2, step=1) + scheduler.schedule(tasks1, batch_id=0) + scheduler.schedule(tasks2, batch_id=1) start_time = time.time() await scheduler.wait_all(timeout=10.0) @@ -239,14 +245,14 @@ async def test_wait_all(self): self.assertEqual(len(scheduler.pending_tasks), 0) self.assertEqual(len(scheduler.running_tasks), 0) - results0 = await scheduler.get_results(step=0, min_num=4, timeout=1) - results1 = await scheduler.get_results(step=1, min_num=3, timeout=1) + results0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) + results1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) self.assertEqual(len(results0), 4) self.assertEqual(len(results1), 3) # test timeout tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) - scheduler.schedule(tasks, step=0) + scheduler.schedule(tasks, batch_id=0) start_time = time.time() with self.assertRaises(TimeoutError): @@ -269,10 +275,10 @@ async def test_concurrent_operations(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() - async def schedule_tasks(step, num_tasks): + async def schedule_tasks(batch_id, num_tasks): tasks = generate_tasks(num_tasks) - scheduler.schedule(tasks, step=step) - return await scheduler.get_results(step=step, min_num=num_tasks, timeout=10) + scheduler.schedule(tasks, batch_id=batch_id) + return await scheduler.get_results(batch_id=batch_id, min_num=num_tasks, timeout=10) results = await asyncio.gather( schedule_tasks(0, 3), @@ -291,15 +297,15 @@ async def test_scheduler_restart_after_stop(self): await scheduler.start() tasks = generate_tasks(2) - scheduler.schedule(tasks, step=0) - results = await scheduler.get_results(step=0, min_num=2, timeout=10) + scheduler.schedule(tasks, batch_id=0) + results = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) self.assertEqual(len(results), 2) await scheduler.stop() await scheduler.start() tasks = 
generate_tasks(3) - scheduler.schedule(tasks, step=1) - results = await scheduler.get_results(step=1, min_num=3, timeout=10) + scheduler.schedule(tasks, batch_id=1) + results = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) self.assertEqual(len(results), 3) await scheduler.stop() @@ -307,29 +313,29 @@ async def test_scheduler_all_methods(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() tasks = generate_tasks(8) - scheduler.schedule(tasks, step=0) + scheduler.schedule(tasks, batch_id=0) self.assertTrue(scheduler.has_step(0)) - results = await scheduler.get_results(step=0, min_num=8, timeout=20) + results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) self.assertEqual(len(results), 8) - scheduler.schedule(tasks, step=1) - scheduler.schedule(tasks[:4], step=2) + scheduler.schedule(tasks, batch_id=1) + scheduler.schedule(tasks[:4], batch_id=2) self.assertFalse(scheduler.has_step(0)) - results = await scheduler.get_results(step=0, min_num=8) + results = await scheduler.get_results(batch_id=0, min_num=8) self.assertFalse(scheduler.has_step(0)) - self.assertEqual(len(results), 0) # step 0 has no more tasks + self.assertEqual(len(results), 0) # batch_id 0 has no more tasks self.assertFalse(scheduler.has_step(0)) self.assertTrue(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) await scheduler.wait_all() st = time.time() - results = await scheduler.get_results(step=1) + results = await scheduler.get_results(batch_id=1) et = time.time() self.assertTrue(et - st < 1.0) self.assertEqual(len(results), 8) self.assertFalse(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) st = time.time() - results = await scheduler.get_results(step=2) + results = await scheduler.get_results(batch_id=2) et = time.time() self.assertTrue(et - st < 1.0) self.assertEqual(len(results), 4) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index d7e07406a2..c0ef426267 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -5,7 +5,7 @@ import asyncio import os import time -from collections import defaultdict +from collections import deque from typing import List, Optional import torch @@ -24,10 +24,10 @@ get_checkpoint_dir_with_step_num, load_state_dict, ) -from trinity.explorer.runner_pool import RunnerPool +from trinity.explorer.scheduler import Scheduler from trinity.manager.manager import CacheManager from trinity.utils.log import get_logger -from trinity.utils.monitor import MONITOR +from trinity.utils.monitor import MONITOR, gather_metrics class Explorer: @@ -38,6 +38,7 @@ def __init__(self, config: Config): self.cache = CacheManager(config) explorer_meta = self.cache.load_explorer() self.explore_step_num = explorer_meta.get("latest_iteration", 0) + self.last_sync_step = self.explore_step_num self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) @@ -51,7 +52,7 @@ def __init__(self, config: Config): self.taskset = get_buffer_reader( self.config.buffer.explorer_input.taskset, self.config.buffer ) - self.runner_pool = self._init_runner_pool() + self.scheduler = self._init_scheduler() self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, name=self.config.name, @@ -65,7 +66,7 @@ def __init__(self, config: Config): self.use_checkpoint_weights_update = ( self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT ) - self.eval_explore_step_num 
= None + self.pending_eval_tasks = deque() # For checkpoint weights update # Use explorer to periodically load the latest model weights and @@ -110,20 +111,14 @@ async def setup_weight_sync_group( ] await asyncio.gather(*refs) - def _init_runner_pool(self) -> RunnerPool: + def _init_scheduler(self) -> Scheduler: if self.config.explorer.rollout_model.engine_type != "vllm_async": # sync model requires the same number of runners as the number of models - self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num + self.config.explorer.runner_per_model = 1 self.logger.info( "Sync vLLM model requires the same number of runners as the number of models" ) - if self.config.explorer.runner_num < self.config.explorer.rollout_model.engine_num: - self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num - self.logger.info( - f"Number of Runners is less than number of models, set to {self.config.explorer.runner_num}" - ) - self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners") - return RunnerPool(self.config, self.models, self.auxiliary_models) + return Scheduler(self.config, self.models, self.auxiliary_models) async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: # TODO: update model weight @@ -140,7 +135,7 @@ async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: ) self.state_dict.clear() - async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: + async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: # TODO: support more checkpoint types try: checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( @@ -149,12 +144,14 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> No step_num=step_num, ) if checkpoint_dir == self.old_checkpoint: - return + return checkpoint_step_num model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor")) await self._update_model_weight(checkpoint_step_num, model_weights) self.old_checkpoint = checkpoint_dir + return checkpoint_step_num except Exception as e: self.logger.warning(f"Fail to load checkpoint: {e}") + return 0 async def _nccl_weights_update(self): assert self.state_dict_meta is not None @@ -164,9 +161,13 @@ async def _nccl_weights_update(self): async def prepare(self) -> None: """Preparation before running.""" + futures = [asyncio.create_task(self.scheduler.start())] if self.use_checkpoint_weights_update: master_address, master_port = await self.models[0].get_available_address.remote() - await self.setup_weight_sync_group(master_address, master_port) + futures.append( + asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) + ) + asyncio.gather(*futures, return_exceptions=True) async def get_weight(self, name: str) -> torch.Tensor: """Get the weight of the loaded model (For checkpoint weights update).""" @@ -174,25 +175,26 @@ async def get_weight(self, name: str) -> torch.Tensor: async def explore(self) -> str: """ - The dreamming loop for explorer and trainer. - | <----------------------------------------- one period ----------------------------------------------> | - explorer | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- eval --> | <-- [idle] --> | <-- sync --> | - trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... 
| <-- step_n --> | <-- [idle] --> | <-- sync --> | + The timeline of the exploration process: + explorer | <--------------------------------- one period -------------------------------------> | + | <------------------------------ eval -------------------------------> | <-- sync --> | + | <---------------- step_1 --------------> | | + | | <---------------- step_2 --------------> | | + | ... | + | | <---------------- step_n ---------------> | | + | | <---------------------- eval --------------------> | <-- sync --> | + trainer |--------------------------------------------------------------------------------------| + | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> | """ - self.eval_explore_step_num = None while True: try: self.logger.info(f"Explore step {self.explore_step_num + 1} started.") - if ( - self.eval_explore_step_num is None - and self.explore_step_num % self.config.explorer.eval_interval == 0 - ): - self.eval_explore_step_num = self.explore_step_num - explore_contionue = self.explore_step() + explore_contionue = await self.explore_step() if not explore_contionue: break + if self.need_eval(): + self.eval() if self.need_sync(): - self.wait_for_workflow_done() await self.sync_weight() except Exception as e: self.logger.error(f"Error in Explorer: {e}") @@ -200,7 +202,7 @@ async def explore(self) -> str: self.logger.info("--------------------\n> Explorer finished.\n--------------------") return self.config.explorer.name - def explore_step(self) -> bool: + async def explore_step(self) -> bool: algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1) # skip warmup if algo_config.algorithm_type == "sft": @@ -210,15 +212,11 @@ def explore_step(self) -> bool: tasks = self.taskset.read() except StopIteration: self.logger.warning("No more tasks to explore. Stop exploring.") - self.cache.save_explorer( - current_step=self.explore_step_num, - current_task_index=self.explore_step_num * self.config.buffer.batch_size, - ) + await self.save_checkpoint(sync_weight=False) self.status = RunningStatus.STOPPED - self.wait_for_workflow_done() self.experience_buffer.release() return False - self.runner_pool.run_tasks(tasks) + self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) self.explore_step_num += 1 return True @@ -229,59 +227,40 @@ def need_sync(self) -> bool: self.explore_step_num - self.config.synchronizer.sync_offset ) % self.config.synchronizer.sync_interval == 0 - def eval(self, eval_explore_step_num: int): + def need_eval(self) -> bool: + return self.explore_step_num % self.config.explorer.eval_interval == 0 + + def eval(self): """Evaluation on all evaluation data samples.""" if len(self.config.buffer.explorer_input.eval_tasksets) == 0: self.logger.warning("No evaluation data samples. Skip evaluation.") return - self.logger.info(f"Evaluation at step {eval_explore_step_num} started.") - all_st = time.time() - log_metrics = {} + self.logger.info(f"Evaluation at step {self.explore_step_num} started.") for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets: self.logger.info( - f"Evaluation on {eval_taskset_config.name} at step {eval_explore_step_num} started." + f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started." 
) eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer) - st = time.time() - all_metrics = defaultdict(list) - - def wait(): - status_list = self.runner_pool.get_next_unorder() - if not isinstance(status_list, list): - status_list = [status_list] - for status in status_list: - if not status.ok: - self.logger.error(f"Error when running task: {status.message}") - else: - for metric_name, metric_value in status.metric.items(): - all_metrics[metric_name].append(metric_value) - + eval_batch_id = f"{self.explore_step_num}/{eval_taskset.name}" + self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name)) while True: - if not self.runner_pool.has_free(): - wait() try: - self.runner_pool.run_tasks(eval_taskset.read()) + self.scheduler.schedule(eval_taskset.read(), batch_id=eval_batch_id) except StopIteration: break - while self.runner_pool.has_next(): - wait() - metrics = self.monitor.calculate_metrics(all_metrics, prefix=f"eval/{eval_taskset.name}") # type: ignore - log_metrics.update(metrics) - log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st - log_metrics["eval/total_time"] = time.time() - all_st - self.monitor.log(log_metrics, step=eval_explore_step_num) # type: ignore - self.logger.info(f"Evaluation at step {eval_explore_step_num} finished.") async def benchmark(self) -> bool: """Benchmark the model checkpoints.""" # benchmark on the latest checkpoint if self.config.explorer.eval_on_latest_checkpoint: - await self._checkpoint_weights_update() - self.eval(self.explore_step_num) + self.explore_step_num = await self._checkpoint_weights_update() + self.eval() + await self._log_eval_metrics() return True # benchmark on base model - self.eval(0) + self.eval() + await self._log_eval_metrics() # benchmark on all checkoints all_ckp_steps = sorted( [ @@ -292,56 +271,69 @@ async def benchmark(self) -> bool: ] ) for step_num in all_ckp_steps: - await self._checkpoint_weights_update(step_num=step_num) - self.eval(step_num) + self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num) + self.eval() + await self._log_eval_metrics() return True - def wait_for_workflow_done(self) -> None: - """Wait for workflow to finish.""" - all_metrics = defaultdict(list) - # wait for all tasks of this step to finish - while self.runner_pool.has_next(): - status_list = self.runner_pool.get_next_unorder() - if not isinstance(status_list, list): - status_list = [status_list] - for status in status_list: - if not status.ok: - self.logger.error(f"Error when running task: {status.message}") - # submit another task to replace the failed task - try: - tasks = self.taskset.read(batch_size=1) - except StopIteration: - self.logger.warning("No more tasks in taskset. 
Stop retrying.") - return - self.runner_pool.run_tasks(tasks) - else: - for metric_name, metric_value in status.metric.items(): - all_metrics[metric_name].append(metric_value) - # eval - if self.eval_explore_step_num is not None: - self.eval(self.eval_explore_step_num) - self.eval_explore_step_num = None - # calculate metrics - log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore - self.monitor.log(log_metrics, step=self.explore_step_num) - self.logger.info(f"Explore step {self.explore_step_num} finished.") + async def save_checkpoint(self, sync_weight: bool = False) -> None: + # wait for all tasks to complete + self.logger.info("Waiting for all tasks to complete") + await self.scheduler.wait_all() + self.logger.info(f"All tasks before step {self.explore_step_num} have completed.") + await self._log_metrics(self.last_sync_step + 1, self.explore_step_num) + + if sync_weight: + # sync weights + self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") + self.status = RunningStatus.WAITING_SYNC + if self.use_checkpoint_weights_update: + await self._checkpoint_weights_update() + else: # nccl weights update + await self._nccl_weights_update() + self.status = RunningStatus.RUNNING + self.last_sync_step = self.explore_step_num + self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} finished") - async def sync_weight(self) -> None: - """Synchronize model weights.""" - # call this method before training start to load the latest model weights - self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.") - self.status = RunningStatus.WAITING_SYNC - if self.use_checkpoint_weights_update: - await self._checkpoint_weights_update() - else: # nccl weights update - await self._nccl_weights_update() # save explore checkpoint self.cache.save_explorer( current_step=self.explore_step_num, current_task_index=self.explore_step_num * self.config.buffer.batch_size, ) - self.status = RunningStatus.RUNNING - self.logger.info(f"Explorer sync at step {self.explore_step_num} finished") + + async def sync_weight(self) -> None: + """Synchronize model weights.""" + # call this method before training start to load the latest model weights + await self.save_checkpoint(sync_weight=True) + + async def _log_metrics(self, start_step: int, end_step: int) -> None: + for step in range(start_step, end_step + 1): + self.logger.info(f"Log metrics of step {step}") + await self._log_explore_metrics(step=step) + await self._log_eval_metrics(step=step) + + async def _log_explore_metrics(self, step: int) -> None: + results = await self.scheduler.get_results(batch_id=step) + metric = gather_metrics([status.metric for status in results], "rollout") + self.monitor.log(metric, step=step) + + async def _log_eval_metrics(self, step: Optional[int] = None) -> None: + if not self.pending_eval_tasks: + return + step = step or self.explore_step_num + st = time.time() + metric = {} + while self.pending_eval_tasks: + eval_step, eval_task_name = self.pending_eval_tasks[0] + if eval_step != step: + return + self.pending_eval_tasks.popleft() + eval_results = await self.scheduler.get_results(f"{step}/{eval_task_name}") + metric.update( + gather_metrics([status.metric for status in eval_results], f"eval/{eval_task_name}") + ) + metric["eval/total_time"] = time.time() - st + self.monitor.log(metric, step) async def running_status(self) -> RunningStatus: return self.status @@ -350,5 +342,6 @@ def flush_log(self, step: int) -> None: """Flush the log of the 
current step.""" self.monitor.log({}, step=step, commit=True) - def shutdown(self) -> None: + async def shutdown(self) -> None: self.monitor.close() + await self.scheduler.stop() diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 119674bef1..f9c94eb117 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -4,7 +4,7 @@ import time import traceback from collections import defaultdict, deque -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import ray @@ -110,11 +110,11 @@ def __init__( self.runner_num = len(rollout_model) * config.explorer.runner_per_model self.runners: Dict[int, RunnerWrapper] = dict() self.idle_runners = set() # runner_id - self.busy_runners = dict() # runner_id -> (task, step) + self.busy_runners = dict() # runner_id -> (task, batch_id) - self.pending_tasks: Dict[int, deque] = defaultdict(deque) # step -> tasks - self.running_tasks: Dict[int, set[asyncio.Future]] = defaultdict(set) # step -> futures - self.completed_tasks: Dict[int, deque[Status]] = defaultdict(deque) # step -> results + self.pending_tasks: Dict[str, deque] = defaultdict(deque) # batch_id -> tasks + self.running_tasks: Dict[str, set[asyncio.Future]] = defaultdict(set) # batch_id -> futures + self.completed_tasks: Dict[str, deque[Status]] = defaultdict(deque) # batch_id -> results self.scheduler_task: Optional[asyncio.Task] = None self.running = False @@ -145,7 +145,7 @@ def _restart_runner(self, runner_id: int): if runner_id in self.busy_runners: task, idx = self.busy_runners.pop(runner_id) self.logger.warning( - f"Runner {runner_id} failed to run task at step {idx}: {task.raw_task}" + f"Runner {runner_id} failed to run task at batch_id {idx}: {task.raw_task}" ) self.idle_runners.add(runner_id) @@ -167,52 +167,52 @@ async def _schedule_pending_tasks(self) -> None: if not self.idle_runners: return - for step in sorted(self.pending_tasks.keys()): - task_queue = self.pending_tasks[step] + for batch_id in sorted(self.pending_tasks.keys()): + task_queue = self.pending_tasks[batch_id] while task_queue and self.idle_runners: task = task_queue.pop() runner_id = self.idle_runners.pop() - self.busy_runners[runner_id] = (task, step) - self.running_tasks[step].add( + self.busy_runners[runner_id] = (task, batch_id) + self.running_tasks[batch_id].add( asyncio.create_task(self.runners[runner_id].run_with_retry(task)) ) if not task_queue: - del self.pending_tasks[step] + del self.pending_tasks[batch_id] async def _check_completed_tasks(self) -> None: - for step in list(self.running_tasks.keys()): - futures = self.running_tasks[step] + for batch_id in list(self.running_tasks.keys()): + futures = self.running_tasks[batch_id] for future in list(futures): if future.done(): futures.remove(future) try: task_result, runner_id = await future - self.completed_tasks[step].appendleft(task_result) + self.completed_tasks[batch_id].appendleft(task_result) self.busy_runners.pop(runner_id) self.idle_runners.add(runner_id) self.logger.debug( - f"Task completed (step {step}), success: {task_result.ok}" + f"Task completed (batch_id {batch_id}), success: {task_result.ok}" ) except Exception as e: self.logger.error(f"Error getting task result: {e}") if not futures: - del self.running_tasks[step] - - def _clear_timeout_tasks(self, step: int) -> None: - if step in self.pending_tasks: - self.logger.info(f"Clear timeout pending tasks at step {step}.") - del self.pending_tasks[step] - if step in self.running_tasks: - 
self.logger.info(f"Clear timeout running tasks at step {step}.") - for future in self.running_tasks[step]: + del self.running_tasks[batch_id] + + def _clear_timeout_tasks(self, batch_id: str) -> None: + if batch_id in self.pending_tasks: + self.logger.info(f"Clear timeout pending tasks at batch_id {batch_id}.") + del self.pending_tasks[batch_id] + if batch_id in self.running_tasks: + self.logger.info(f"Clear timeout running tasks at batch_id {batch_id}.") + for future in self.running_tasks[batch_id]: future.cancel() - del self.running_tasks[step] + del self.running_tasks[batch_id] async def start(self) -> None: if self.running: @@ -246,61 +246,70 @@ async def stop(self) -> None: pass self.logger.info("Scheduler stopped") - def schedule(self, tasks: List[Task], step: int) -> None: + def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: """Schedule the provided tasks. Args: tasks (`List[Task]`): The tasks to schedule. - step (`int`): The step number of provided tasks. + batch_id (`Union[int, str]`): The id of provided tasks. """ if not tasks: return + batch_id = str(batch_id) for task in tasks: - self.pending_tasks[step].appendleft(task) + self.pending_tasks[batch_id].appendleft(task) + self.logger.info(f"Scheduled {len(tasks)} tasks for batch {batch_id}") async def get_results( - self, step: int, min_num: Optional[int] = None, timeout: Optional[float] = None + self, + batch_id: Union[int, str], + min_num: Optional[int] = None, + timeout: Optional[float] = None, + clear_timeout_tasks: bool = True, ) -> List[Status]: - """Get the result of tasks at the specific step. + """Get the result of tasks at the specific batch_id. Args: - step (`int`): Only wait for tasks at this step. - min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `step`. + batch_id (`Union[int, str]`): Only wait for tasks at this batch. + min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`. timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. + clear_timeout_tasks (`bool`): Whether to clear timeout tasks. 
""" timeout = timeout or self.timeout + batch_id = str(batch_id) start_time = time.time() if min_num is None: min_num = 0 - if step in self.pending_tasks: - min_num += len(self.pending_tasks[step]) - if step in self.running_tasks: - min_num += len(self.running_tasks[step]) - if step in self.completed_tasks: - min_num += len(self.completed_tasks[step]) + if batch_id in self.pending_tasks: + min_num += len(self.pending_tasks[batch_id]) + if batch_id in self.running_tasks: + min_num += len(self.running_tasks[batch_id]) + if batch_id in self.completed_tasks: + min_num += len(self.completed_tasks[batch_id]) self.logger.debug(f"Waiting for {min_num} tasks to complete...") while time.time() - start_time < timeout: - completed_count = len(self.completed_tasks[step]) + completed_count = len(self.completed_tasks[batch_id]) if completed_count >= min_num: break await asyncio.sleep(0.1) if time.time() - start_time > timeout: self.logger.error(f"Timed out waiting for tasks to complete after {timeout} seconds") - self._clear_timeout_tasks(step=step) - for runner_id in list(self.busy_runners.keys()): - if self.busy_runners[runner_id][1] == step: - self._restart_runner(runner_id) + if clear_timeout_tasks: + self._clear_timeout_tasks(batch_id=batch_id) + for runner_id in list(self.busy_runners.keys()): + if self.busy_runners[runner_id][1] == batch_id: + self._restart_runner(runner_id) results = [] for _ in range(min_num): - if len(self.completed_tasks[step]) > 0: - results.append(self.completed_tasks[step].pop()) + if len(self.completed_tasks[batch_id]) > 0: + results.append(self.completed_tasks[batch_id].pop()) - if not self.completed_tasks[step]: - del self.completed_tasks[step] + if not self.completed_tasks[batch_id]: + del self.completed_tasks[batch_id] completed_count = len(results) if completed_count < min_num: @@ -310,13 +319,23 @@ async def get_results( return results - def has_step(self, step: int) -> bool: + def has_step(self, batch_id: Union[int, str]) -> bool: + batch_id = str(batch_id) return ( - step in self.completed_tasks or step in self.pending_tasks or step in self.running_tasks + batch_id in self.completed_tasks + or batch_id in self.pending_tasks + or batch_id in self.running_tasks ) - async def wait_all(self, timeout: Optional[float] = None) -> None: - """Wait for all tasks to complete without poping results. If timeout reached, raise TimeoutError.""" + async def wait_all( + self, timeout: Optional[float] = None, clear_timeout_tasks: bool = True + ) -> None: + """Wait for all tasks to complete without poping results. If timeout reached, raise TimeoutError. + + Args: + timeout (`float`): timeout in seconds. + clear_timeout_tasks (`bool`): Whether to clear timeout tasks. + """ timeout = timeout or self.timeout start_time = time.time() @@ -336,17 +355,16 @@ async def wait_all(self, timeout: Optional[float] = None) -> None: self.logger.debug(f"Pending tasks: {pending_count}, Running tasks: {running_count}") await asyncio.sleep(0.1) - pending_count = sum(len(tasks) for tasks in self.pending_tasks.values()) running_count = sum(len(futures) for futures in self.running_tasks.values()) - for step in self.pending_tasks.keys() | self.running_tasks.keys(): - self._clear_timeout_tasks(step) - error_msg = f"Timeout after {timeout} seconds. Still have {pending_count} pending tasks and {running_count} running tasks." 
self.logger.error(error_msg) - busy_runner_ids = list(self.busy_runners.keys()) - for runner_id in busy_runner_ids: - self._restart_runner(runner_id) + if clear_timeout_tasks: + for batch_id in self.pending_tasks.keys() | self.running_tasks.keys(): + self._clear_timeout_tasks(batch_id) + busy_runner_ids = list(self.busy_runners.keys()) + for runner_id in busy_runner_ids: + self._restart_runner(runner_id) raise TimeoutError(error_msg) diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index e83df10b8f..113ee7d69e 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -2,7 +2,7 @@ import os from abc import ABC, abstractmethod -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import pandas as pd @@ -16,6 +16,18 @@ MONITOR = Registry("monitor") +def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict: + df = pd.DataFrame(metric_list) + numeric_df = df.select_dtypes(include=[np.number]) + stats_df = numeric_df.agg(["mean", "max", "min"]) + metric = {} + for col in stats_df.columns: + metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col] + metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col] + metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col] + return metric + + class Monitor(ABC): """Monitor""" @@ -111,3 +123,23 @@ def log(self, data: dict, step: int, commit: bool = False) -> None: def close(self) -> None: self.logger.finish() + + +if __name__ == "__main__": + metric_list = [ + { + "rollout/reward": 1.0, + "rollout/time": 12.0, + }, + { + "rollout/reward": 2.0, + "rollout/time": 13.0, + }, + { + "rollout/reward": 2.0, + }, + { + "rollout/time": 14.0, + }, + ] + print(gather_metrics(metric_list, prefix="explorer")) From 2aa5f3780bc02207fe16f9b548808d08926c31a3 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 3 Jul 2025 17:31:22 +0800 Subject: [PATCH 05/20] fix tests --- tests/buffer/file_test.py | 2 + tests/explorer/runner_pool_test.py | 255 ---------------------- trinity/buffer/writer/file_writer.py | 2 +- trinity/explorer/__init__.py | 3 +- trinity/explorer/explorer.py | 10 +- trinity/explorer/runner_pool.py | 303 --------------------------- trinity/explorer/scheduler.py | 2 +- trinity/utils/monitor.py | 20 -- 8 files changed, 11 insertions(+), 586 deletions(-) delete mode 100644 tests/explorer/runner_pool_test.py delete mode 100644 trinity/explorer/runner_pool.py diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 2882dd8e0f..4612d2ae7f 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -97,12 +97,14 @@ def test_file_writer(self): writer = get_buffer_writer( self.config.buffer.trainer_input.experience_buffer, self.config.buffer ) + writer.acquire() writer.write( [ {"prompt": "hello world"}, {"prompt": "hi"}, ] ) + writer.release() file_wrapper = ray.get_actor("json-test_buffer") self.assertIsNotNone(file_wrapper) file_path = default_storage_path( diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py deleted file mode 100644 index 1ba7731efc..0000000000 --- a/tests/explorer/runner_pool_test.py +++ /dev/null @@ -1,255 +0,0 @@ -import copy -import os -import time -import unittest -from typing import List, Tuple - -import ray -import torch - -from tests.tools import get_unittest_dataset_config -from trinity.buffer.reader.queue_reader import QueueReader -from trinity.common.config import InferenceModelConfig, StorageConfig, load_config -from trinity.common.constants import StorageType -from 
trinity.common.experience import Experience -from trinity.common.models.model import InferenceModel -from trinity.common.workflows import Task -from trinity.common.workflows.workflow import WORKFLOWS, Workflow -from trinity.explorer.runner_pool import RunnerPool - -config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data", "template.yaml") - - -@WORKFLOWS.register_module("dummy_workflow") -class DummyWorkflow(Workflow): - def __init__(self, model, task, auxiliary_models): - super().__init__(model, task, auxiliary_models) - self.error_type = task.task_desc - self.seconds = None - if "timeout" in self.error_type: - self.seconds = int(self.error_type.split("_")[-1]) - - def run(self) -> List[Experience]: - if "timeout" in self.error_type: - time.sleep(self.seconds) - elif self.error_type == "exception": - raise ValueError("Exception occurred") - elif self.error_type == "exit": - exit(1) - elif self.error_type == "auxiliary_models": - assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 - return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)] - - -@ray.remote -class DummyModel(InferenceModel): - def sync_model(self, model_version, update_weight_args_list): - return True - - def get_model_version(self): - return 0 - - def init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - ) -> None: - pass - - -@ray.remote -class DummyAuxiliaryModel(InferenceModel): - def sync_model(self, model_version, update_weight_args_list): - return True - - def get_model_version(self): - return 0 - - def init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - ) -> None: - pass - - def has_api_server(self) -> bool: - return True - - def api_server_ready(self) -> Tuple[str, str]: - return "http://localhosts:12345", "placeholder" - - -class RunnerPoolTest(unittest.TestCase): - def setUp(self): - ray.init(ignore_reinit_error=True) - self.config = load_config(config_dir) - self.config.explorer.runner_num = 2 - self.config.explorer.max_retry_times = 0 - self.config.explorer.max_timeout = 5 - self.config.buffer.read_batch_size = 2 - self.config.buffer.pad_token_id = 0 - self.config.buffer.explorer_output = ( - self.config.buffer.trainer_input.experience_buffer - ) = StorageConfig( - name="test", - storage_type=StorageType.QUEUE, - algorithm_type="ppo", - path="", - ) - self.queue = QueueReader( - self.config.buffer.trainer_input.experience_buffer, self.config.buffer - ) - - def test_runner_pool(self): - pool = RunnerPool(self.config, [DummyModel.remote(), DummyModel.remote()]) - taskset_config = get_unittest_dataset_config("countdown") - tasks = [ - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "timeout_100", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "exception", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - 
taskset_config.format.prompt_key: "timeout_2", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "success", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "timeout_101", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "exit", - }, - ), - ] - - pool.run_tasks( - tasks=tasks, - ) - - # The excepted return order is: `exception` -> `timeout_2` -> `success` -> (`timeout_100`and `timeout_101`) -> `exit` - # 1. `exception` - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st < 2) - print(f"First task use time: {et - st}") - self.assertEqual(len(status), 1) - self.assertFalse(status[0].ok) - # 2. `timeout_2 - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st > 2) - self.assertEqual(len(status), 1) - self.assertTrue(status[0].ok) - # 3. `success` - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st < 1) - self.assertEqual(len(status), 1) - self.assertTrue(status[0].ok) - # 4. `timeout_100`and `timeout_101` - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st > 5) - self.assertEqual(len(status), 2) - self.assertFalse(status[0].ok) - self.assertFalse(status[1].ok) - - # 5.`exit` - status = pool.get_next_unorder() - self.assertEqual(len(status), 1) - self.assertFalse(status[0].ok) - - exps = self.queue.read() - self.assertEqual(len(exps), 2) # `timeout_2` and `success` - self.assertEqual(len(pool._idle_actors), self.config.explorer.runner_num) - - def test_runner_pool_with_auxiliary_models(self): - config = copy.deepcopy(self.config) - config.explorer.auxiliary_models = [ - InferenceModelConfig( - engine_num=1, - ), - InferenceModelConfig( - engine_num=1, - ), - ] - pool = RunnerPool( - config, - [DummyModel.remote(), DummyModel.remote()], - [[DummyAuxiliaryModel.remote()], [DummyAuxiliaryModel.remote()]], - ) - taskset_config = get_unittest_dataset_config("countdown") - tasks = [ - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "auxiliary_models", - }, - ), - ] - - pool.run_tasks( - tasks=tasks, - ) - - # `auxiliary_models` - status = pool.get_next_unorder() - self.assertEqual(len(status), 1) - self.assertTrue(status[0].ok) diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index 16ec96d0a9..b2dff825de 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -22,7 +22,7 @@ def write(self, data: List) -> None: def acquire(self) -> int: if self.wrap_in_ray: - return ray.get(self.writer.acquire()) + return ray.get(self.writer.acquire.remote()) else: return 0 diff --git a/trinity/explorer/__init__.py b/trinity/explorer/__init__.py index e7794c7cf6..8665a1b125 100644 --- a/trinity/explorer/__init__.py +++ b/trinity/explorer/__init__.py @@ -1,4 +1,3 @@ from trinity.explorer.explorer import Explorer -from trinity.explorer.runner_pool import RunnerPool -__all__ = ["Explorer", "RunnerPool"] +__all__ = 
["Explorer"] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index c0ef426267..7608e17460 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -5,6 +5,7 @@ import asyncio import os import time +import traceback from collections import deque from typing import List, Optional @@ -196,8 +197,8 @@ async def explore(self) -> str: self.eval() if self.need_sync(): await self.sync_weight() - except Exception as e: - self.logger.error(f"Error in Explorer: {e}") + except Exception: + self.logger.error(f"Error in Explorer: {traceback.format_exc()}") break self.logger.info("--------------------\n> Explorer finished.\n--------------------") return self.config.explorer.name @@ -314,8 +315,9 @@ async def _log_metrics(self, start_step: int, end_step: int) -> None: async def _log_explore_metrics(self, step: int) -> None: results = await self.scheduler.get_results(batch_id=step) - metric = gather_metrics([status.metric for status in results], "rollout") - self.monitor.log(metric, step=step) + if results: + metric = gather_metrics([status.metric for status in results], "rollout") + self.monitor.log(metric, step=step) async def _log_eval_metrics(self, step: Optional[int] = None) -> None: if not self.pending_eval_tasks: diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py deleted file mode 100644 index e5ef8bdd5d..0000000000 --- a/trinity/explorer/runner_pool.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Runner pool for running tasks in parallel. Modified from ray.util.actor_pool.ActorPool.""" -import random -from typing import List, Optional, Tuple, Union - -import ray - -from trinity.common.config import Config -from trinity.common.models.model import InferenceModel -from trinity.common.workflows import Task -from trinity.explorer.workflow_runner import Status, WorkflowRunner -from trinity.utils.log import get_logger - - -class RunnerPool: - """A pool of WorkflowRunner. - - The RunnerPool will automatically handle the exceptions during the workflow - and retry when the workflow fails or timeout. The number of max retries is - set in `config.explorer.max_retry_times` and the max timeout is set in - `config.explorer.max_timeout`. 
- """ - - def __init__( - self, - config: Config, - models: List[InferenceModel], - auxiliary_models: Optional[List[List[InferenceModel]]] = None, - ): - # actors to be used - self.logger = get_logger(__name__) - self.config = config - self.models = models - self.auxiliary_models = auxiliary_models or [] - self.timeout = config.explorer.max_timeout - self.max_retry_times = config.explorer.max_retry_times - - # get actor from future - self._future_to_actor = {} - - # get future from index - self._index_to_future = {} - - # next task to do - self._next_task_index = 0 - - # next task to return - self._next_return_index = 0 - - # next work depending when actors free - self._pending_submits = [] - - # create new actors - self.engine_status = [0] * config.explorer.rollout_model.engine_num - self.auxiliary_engine_status_list = [ - [0] * cfg.engine_num for cfg in config.explorer.auxiliary_models - ] - self._idle_actors = list() - self.actor_to_engine_index = {} - self._namespace = ray.get_runtime_context().namespace - self._create_actors(config.explorer.runner_num) - - def _create_actors(self, num: int = 1): - new_actors = [] - for _ in range(num): - engine_index = self.engine_status.index(min(self.engine_status)) - selected_auxiliary_models = [ - models[engine_status.index(min(engine_status))] - for models, engine_status in zip( - self.auxiliary_models, self.auxiliary_engine_status_list - ) - ] - new_actor = ( - ray.remote(WorkflowRunner) - .options( - namespace=self._namespace, - scheduling_strategy="SPREAD", - runtime_env={"env_vars": self.config.explorer.env_vars}, - ) - .remote( - self.config, - self.models[engine_index], - selected_auxiliary_models, - ) - ) - new_actors.append(new_actor) - self.engine_status[engine_index] += 1 - self.actor_to_engine_index[new_actor] = engine_index - for actor in new_actors: - self._return_actor(actor) - - def _kill_actors(self, actors): - if not isinstance(actors, list): - actors = [actors] - - for actor in actors: - release_engine_index = self.actor_to_engine_index[actor] - self.engine_status[release_engine_index] -= 1 - del self.actor_to_engine_index[actor] - ray.kill(actor) - - def _run_task(self, task: Task, retry_times: int = 0) -> None: - """Run a task in the pool. - - Arguments: - task: A task to run. - retry_times: The current retry times of the task. - """ - if self._idle_actors: - actor = self._idle_actors.pop() - future = actor.run_task.remote(task) - future_key = tuple(future) if isinstance(future, list) else future - self._future_to_actor[future_key] = (task, actor, retry_times) - self._index_to_future[self._next_task_index] = future - self._next_task_index += 1 - else: - self._pending_submits.append((task, retry_times)) - - def run_tasks(self, tasks: Union[List[Task], Task]) -> None: - """Schedule a list of tasks to run in the pool. - - Arguments: - tasks: A list of tasks. - """ - if isinstance(tasks, Task): - tasks = [tasks] - for task in tasks: - self._run_task(task, 0) - - def has_next(self): - """Returns whether there are any pending results to return. - - Returns: - True if there are any pending results not yet returned. 
- """ - return bool(self._future_to_actor) - - def _handle_single_future(self, future, is_timeout) -> Tuple[Status, Task, int]: - future_key = tuple(future) if isinstance(future, list) else future - t, a, r = self._future_to_actor.pop(future_key) - - if is_timeout: - # when timeout, restart the actor - self.logger.warning(f"Workflow {t.task_desc} Timeout.") - - # kill the actor and update engine status - self._kill_actors(a) - - # start a new actor - self._create_actors(num=1) - - return_status = Status( - False, metric={"time_per_task": self.timeout}, message="Workflow Timeout." - ) - else: - self._return_actor(a) - try: - return_status = ray.get(future) - except Exception as e: - self.logger.error(f"Error when running task: {e}") - return_status = Status( - False, - metric={"time_per_task": self.timeout}, - message=f"Error when running task: {e}", - ) - return return_status, t, r - - def get_next_unorder(self) -> List[Status]: - """Returns the next pending result unorder. - - Returns: - The return status of the next task. - """ - if not self.has_next(): - raise StopIteration("No more results to get") - is_timeout = False - res, _ = ray.wait(list(self._future_to_actor), num_returns=1, timeout=self.timeout) - if not res: - is_timeout = True - future_list = list(self._future_to_actor) - else: - future_list = res - - return_status_list = list() - for future in future_list: - return_status, t, r = self._handle_single_future(future, is_timeout) - - if not return_status.ok: - if r >= self.max_retry_times: - return_status_list.append( - Status( - False, - metric={"retry_times": r + 1}, - message=f"{return_status.message}\nWorkflow Retry Times Exceeded.", - ) - ) - else: - self.logger.info(f"Retry Workflow {t.task_desc}.") - self._run_task(t, r + 1) - else: - return_status_list.append(return_status) - - return return_status_list if return_status_list else self.get_next_unorder() - - # todo: this function may be discarded in the next version - def get_next(self) -> Status: - """Returns the next pending result in order. - - This returns the next task result, blocking for up to - the specified timeout until it is available. - - Returns: - The return status of the next task. - """ - if not self.has_next(): - raise StopIteration("No more results to get") - future = self._index_to_future[self._next_return_index] - is_timeout = False - res, _ = ray.wait([future], timeout=self.timeout) - if not res: - is_timeout = True - del self._index_to_future[self._next_return_index] - self._next_return_index += 1 - - future_key = tuple(future) if isinstance(future, list) else future - t, a, r = self._future_to_actor.pop(future_key) - - if is_timeout: - # when timeout, restart the actor - self.logger.warning(f"Workflow {t.task_desc} Timeout.") - ray.kill(a) - # TODO: balance the model - self._return_actor( - ray.remote(WorkflowRunner) - .options( - namespace=self._namespace, - scheduling_strategy="SPREAD", - ) - .remote( - self.config, - self.models[ - random.randint(0, self.config.explorer.rollout_model.engine_num - 1) - ], - ) - ) - return_status = Status( - False, metric={"time_per_task": self.timeout}, message="Workflow Timeout." 
- ) - else: - self._return_actor(a) - try: - return_status = ray.get(future) - except Exception as e: - self.logger.error(f"Error when running task: {e}") - return_status = Status( - False, - metric={"time_per_task": self.timeout}, - message=f"Error when running task: {e}", - ) - - if not return_status.ok: - if r >= self.max_retry_times: - return Status( - False, - metric={"retry_times": r + 1}, - message=f"{return_status.message}\nWorkflow Retry Times Exceeded.", - ) - else: - self.logger.info(f"Retry Workflow {t.task_desc}.") - self._run_task(t, r + 1) - return self.get_next() - else: - return return_status - - def _return_actor(self, actor): - try: - ray.get(actor.is_alive.remote()) - self._idle_actors.append(actor) - except Exception: - self.logger.info("The actor is not alive, restart a new actor") - self._kill_actors(actor) - self._create_actors(num=1) - - if self._pending_submits: - self._run_task(*self._pending_submits.pop(0)) - - def has_free(self): - """Returns whether there are any idle actors available. - - Returns: - True if there are any idle actors and no pending submits. - """ - return len(self._idle_actors) > 0 and len(self._pending_submits) == 0 - - def pop_idle(self): - """Removes an idle actor from the pool. - - Returns: - An idle actor if one is available. - None if no actor was free to be removed. - """ - if self.has_free(): - return self._idle_actors.pop() - return None diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index f9c94eb117..e8ec068447 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -167,6 +167,7 @@ async def _schedule_pending_tasks(self) -> None: if not self.idle_runners: return + # TODO: Support more advanced scheduling strategies for batch_id in sorted(self.pending_tasks.keys()): task_queue = self.pending_tasks[batch_id] @@ -258,7 +259,6 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: batch_id = str(batch_id) for task in tasks: self.pending_tasks[batch_id].appendleft(task) - self.logger.info(f"Scheduled {len(tasks)} tasks for batch {batch_id}") async def get_results( self, diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 113ee7d69e..5896fc110d 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -123,23 +123,3 @@ def log(self, data: dict, step: int, commit: bool = False) -> None: def close(self) -> None: self.logger.finish() - - -if __name__ == "__main__": - metric_list = [ - { - "rollout/reward": 1.0, - "rollout/time": 12.0, - }, - { - "rollout/reward": 2.0, - "rollout/time": 13.0, - }, - { - "rollout/reward": 2.0, - }, - { - "rollout/time": 14.0, - }, - ] - print(gather_metrics(metric_list, prefix="explorer")) From 292795a90d23bf6f06806a83f99c0f6835d345b7 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 3 Jul 2025 20:27:08 +0800 Subject: [PATCH 06/20] add docs --- .../source/tutorial/trinity_configs.md | 9 ++++- tests/template/config.yaml | 2 +- tests/trainer/trainer_test.py | 21 +++++----- trinity/common/config.py | 13 ++++--- trinity/explorer/explorer.py | 16 ++++---- trinity/explorer/scheduler.py | 38 ++++++++++++++----- trinity/manager/config_manager.py | 10 +++-- .../explorer_config_manager.py | 8 ++-- 8 files changed, 74 insertions(+), 43 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index e96cc678b7..ab647501b0 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ 
-313,7 +313,7 @@ Controls the rollout models and workflow execution.
 ```yaml
 explorer:
   name: explorer
-  runner_num: 32
+  runner_per_model: 8
   max_timeout: 900
   max_retry_times: 2
   env_vars: {}
@@ -324,10 +324,12 @@ explorer:
   auxiliary_models:
     - model_path: /PATH/TO/MODEL
       tensor_parallel_size: 1
+  eval_interval: 100
+  eval_on_startup: True
 ```
 - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
-- `runner_num`: Number of parallel workflow runners.
+- `runner_per_model`: Number of parallel workflow runners per rollout model.
 - `max_timeout`: Maximum time (in seconds) for a workflow to complete.
 - `max_retry_times`: Maximum number of retries for a workflow.
 - `env_vars`: Environment variables to be set for every workflow runners.
@@ -335,6 +337,9 @@ explorer:
 - `rollout_model.engine_num`: Number of inference engines.
 - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
 - `auxiliary_models`: Additional models used for custom workflows.
+- `eval_interval`: Interval (in steps) for evaluating the model.
+- `eval_on_startup`: Whether to evaluate the model at startup, i.e., at step 0 with the original model; it is not triggered again when restarting from a checkpoint.
+- `runner_num`: (*Deprecated*) Number of parallel workflow runners.
 ---
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 98180fff48..aaca7ff0a8 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -37,7 +37,7 @@ buffer:
   default_reward_fn_type: ''
 explorer:
   eval_interval: 100
-  runner_num: 4
+  runner_per_model: 8
   rollout_model:
     engine_type: vllm_async
     engine_num: 2
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 27aad9d8cc..67e44152df 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -60,7 +60,7 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("copy_countdown", "test")
         )
-        self.config.trainer.save_interval = 4
+        self.config.trainer.save_interval = 6
         self.config.check_and_update()
         self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
         self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2
@@ -84,24 +84,25 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        checkpoint_step_4, _ = get_checkpoint_dir_with_step_num(
+        checkpoint_step_6, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
-            step_num=4,
+            step_num=6,
         )
-        checkpoint_step_8, _ = get_checkpoint_dir_with_step_num(
+        # check that the latest checkpoint is saved
+        checkpoint_step_8, step_num = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
-            step_num=8,
         )
-        self.assertTrue(os.path.exists(checkpoint_step_4))
-        self.assertTrue(os.path.exists(checkpoint_step_8))
+        self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_6, "actor"))) > 0)
+        self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
+        self.assertEqual(step_num, 8)
         # TODO: Reinit will fail when using v1 engine, find a way to fix it
         ray.init(ignore_reinit_error=True)
         # test bench mode
         self.config.mode = "bench"
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
-        self.config.explorer.eval_on_latest_checkpoint = False
+        self.config.explorer.bench_on_latest_checkpoint
= False self.config.check_and_update() bench(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) @@ -116,7 +117,8 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + # shutil.rmtree(self.config.checkpoint_job_dir) + pass class TestStepAheadAsyncRL(BaseTrainerCase): @@ -328,7 +330,6 @@ def test_fully_async_mode(self): config.cluster.node_num = 1 explorer1_config.explorer.rollout_model.engine_num = 1 explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 - explorer1_config.explorer.runner_num = 4 explorer1_config.buffer.explorer_output = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, diff --git a/trinity/common/config.py b/trinity/common/config.py index a6d7eba036..6825342013 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -301,12 +301,12 @@ class ExplorerConfig: name: str = EXPLORER_NAME # for workflow runner # number of workflow runners. - # For sync engine (vllm), it should be equal to `engine_num`. - # For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num` - runner_num: int = 1 + # For sync engine (vllm), it should be `1`. + # For async engine (vllm_async), it could be a large number. + runner_per_model: int = 8 # number of runners per each rollout model max_timeout: int = 900 # wait each task for 15 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout - runner_per_model: int = 8 + runner_num: Optional[int] = None # deprecated # for inference models # for rollout model @@ -316,7 +316,10 @@ class ExplorerConfig: # for evaluation eval_interval: int = 100 - eval_on_latest_checkpoint: bool = False + eval_on_startup: bool = True # evalulate at step 0 + + # for benchmark + bench_on_latest_checkpoint: bool = False # only benchmark the latest checkpoint @dataclass diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 7608e17460..ed0c8fe886 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -39,7 +39,7 @@ def __init__(self, config: Config): self.cache = CacheManager(config) explorer_meta = self.cache.load_explorer() self.explore_step_num = explorer_meta.get("latest_iteration", 0) - self.last_sync_step = self.explore_step_num + self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) @@ -169,6 +169,8 @@ async def prepare(self) -> None: asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) ) asyncio.gather(*futures, return_exceptions=True) + if self.config.explorer.eval_on_startup and self.explore_step_num == 0: + self.eval() async def get_weight(self, name: str) -> torch.Tensor: """Get the weight of the loaded model (For checkpoint weights update).""" @@ -177,21 +179,21 @@ async def get_weight(self, name: str) -> torch.Tensor: async def explore(self) -> str: """ The timeline of the exploration process: - explorer | <--------------------------------- one period -------------------------------------> | - | <------------------------------ eval -------------------------------> | <-- sync --> | - | <---------------- step_1 --------------> | | + | <--------------------------------- one period -------------------------------------> | + explorer | <---------------- step_1 --------------> | | | | <---------------- 
step_2 --------------> | | | ... | | | <---------------- step_n ---------------> | | | | <---------------------- eval --------------------> | <-- sync --> | - trainer |--------------------------------------------------------------------------------------| - | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> | + |--------------------------------------------------------------------------------------| + trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> | """ while True: try: self.logger.info(f"Explore step {self.explore_step_num + 1} started.") explore_contionue = await self.explore_step() if not explore_contionue: + # TODO: support eval on last checkpoint break if self.need_eval(): self.eval() @@ -253,7 +255,7 @@ def eval(self): async def benchmark(self) -> bool: """Benchmark the model checkpoints.""" # benchmark on the latest checkpoint - if self.config.explorer.eval_on_latest_checkpoint: + if self.config.explorer.bench_on_latest_checkpoint: self.explore_step_num = await self._checkpoint_weights_update() self.eval() await self._log_eval_metrics() diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index e8ec068447..1f5d07706b 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -1,6 +1,7 @@ """Scheduler for rollout tasks.""" import asyncio +import re import time import traceback from collections import defaultdict, deque @@ -89,6 +90,20 @@ def restart_runner(self): pass +def sort_batch_id(batch_id: Union[int, str]): + """Priority of batch_id""" + # TODO: avoid sort the batch_id every time + if isinstance(batch_id, int): + return (batch_id, 0) + else: + match = re.match(r"^(\d+)", batch_id) + if match: + num = int(match.group(1)) + return (num, 1) + else: + return (float("inf"), 1) + + class Scheduler: """Scheduler for rollout tasks.""" @@ -112,9 +127,14 @@ def __init__( self.idle_runners = set() # runner_id self.busy_runners = dict() # runner_id -> (task, batch_id) - self.pending_tasks: Dict[str, deque] = defaultdict(deque) # batch_id -> tasks - self.running_tasks: Dict[str, set[asyncio.Future]] = defaultdict(set) # batch_id -> futures - self.completed_tasks: Dict[str, deque[Status]] = defaultdict(deque) # batch_id -> results + self.pending_tasks_heap = [] + self.pending_tasks: Dict[Union[int, str], deque] = defaultdict(deque) # batch_id -> tasks + self.running_tasks: Dict[Union[int, str], set[asyncio.Future]] = defaultdict( + set + ) # batch_id -> futures + self.completed_tasks: Dict[Union[int, str], deque[Status]] = defaultdict( + deque + ) # batch_id -> results self.scheduler_task: Optional[asyncio.Task] = None self.running = False @@ -168,7 +188,7 @@ async def _schedule_pending_tasks(self) -> None: return # TODO: Support more advanced scheduling strategies - for batch_id in sorted(self.pending_tasks.keys()): + for batch_id in sorted(self.pending_tasks.keys(), key=sort_batch_id): task_queue = self.pending_tasks[batch_id] while task_queue and self.idle_runners: @@ -205,7 +225,7 @@ async def _check_completed_tasks(self) -> None: if not futures: del self.running_tasks[batch_id] - def _clear_timeout_tasks(self, batch_id: str) -> None: + def _clear_timeout_tasks(self, batch_id: Union[int, str]) -> None: if batch_id in self.pending_tasks: self.logger.info(f"Clear timeout pending tasks at batch_id {batch_id}.") del self.pending_tasks[batch_id] @@ -252,11 +272,11 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: Args: tasks 
(`List[Task]`): The tasks to schedule. - batch_id (`Union[int, str]`): The id of provided tasks. + batch_id (`Union[int, str]`): The id of provided tasks. It should be an integer or a string + starting with an integer (e.g., 123, "123/my_task") """ if not tasks: return - batch_id = str(batch_id) for task in tasks: self.pending_tasks[batch_id].appendleft(task) @@ -276,7 +296,6 @@ async def get_results( clear_timeout_tasks (`bool`): Whether to clear timeout tasks. """ timeout = timeout or self.timeout - batch_id = str(batch_id) start_time = time.time() if min_num is None: min_num = 0 @@ -320,7 +339,6 @@ async def get_results( return results def has_step(self, batch_id: Union[int, str]) -> bool: - batch_id = str(batch_id) return ( batch_id in self.completed_tasks or batch_id in self.pending_tasks @@ -353,8 +371,8 @@ async def wait_all( running_count = sum(len(futures) for futures in self.running_tasks.values()) self.logger.debug(f"Pending tasks: {pending_count}, Running tasks: {running_count}") - await asyncio.sleep(0.1) + pending_count = sum(len(tasks) for tasks in self.pending_tasks.values()) running_count = sum(len(futures) for futures in self.running_tasks.values()) error_msg = f"Timeout after {timeout} seconds. Still have {pending_count} pending tasks and {running_count} running tasks." diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index b9ba995985..b468382300 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -199,9 +199,11 @@ def _expert_buffer_part(self): def _expert_explorer_part(self): self.get_configs("sync_method", "sync_interval", "sync_timeout") - self.get_configs("runner_num", "max_timeout", "explorer_max_retry_times", "eval_interval") + self.get_configs( + "runner_per_model", "max_timeout", "explorer_max_retry_times", "eval_interval" + ) - self.get_configs("eval_on_latest_checkpoint") + self.get_configs("bench_on_latest_checkpoint") with st.expander("Rollout Model Config", expanded=True): self.get_configs("engine_type", "engine_num", "tensor_parallel_size") @@ -571,7 +573,7 @@ def _gen_buffer_config(self): def _gen_explorer_config(self): explorer_config = { - "runner_num": st.session_state["runner_num"], + "runner_per_model": st.session_state["runner_per_model"], "max_timeout": st.session_state["max_timeout"], "max_retry_times": st.session_state["explorer_max_retry_times"], "rollout_model": { @@ -584,7 +586,7 @@ def _gen_explorer_config(self): }, "auxiliary_models": [], "eval_interval": st.session_state["eval_interval"], - "eval_on_latest_checkpoint": st.session_state["eval_on_latest_checkpoint"], + "bench_on_latest_checkpoint": st.session_state["bench_on_latest_checkpoint"], } for i in range(st.session_state["_auxiliary_models_num"]): auxiliary_model_config = { diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py index 12e8034a30..249c669f60 100644 --- a/trinity/manager/config_registry/explorer_config_manager.py +++ b/trinity/manager/config_registry/explorer_config_manager.py @@ -9,9 +9,9 @@ def explorer_visible() -> bool: return st.session_state["mode"] == "both" -@CONFIG_GENERATORS.register_config(default_value=32, visible=explorer_visible) -def set_runner_num(**kwargs): - st.number_input("Runner Num", min_value=1, **kwargs) +@CONFIG_GENERATORS.register_config(default_value=8, visible=explorer_visible) +def set_runner_per_model(**kwargs): + st.number_input("Runner per Model", min_value=1, **kwargs) 
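The scheduler changes in `trinity/explorer/scheduler.py` above key everything by `batch_id`: pending batches are drained in the order given by `sort_batch_id` (integer ids first, then string ids ranked by their leading integer), results are collected per batch with `get_results`, and the explorer aggregates them with `gather_metrics`. The sketch below illustrates the intended call pattern; it is not part of the patch and assumes a constructed `Scheduler` that has already been started with `await scheduler.start()`, a list of `Task` objects, and the import paths as defined in this series. The `"countdown"` taskset name is a placeholder.

```python
from trinity.explorer.scheduler import Scheduler, sort_batch_id
from trinity.utils.monitor import gather_metrics

# Ordering used when draining pending batches: plain integers first,
# then strings ranked by their leading integer, e.g.
#   sorted(["2/countdown", 1, "10/gsm8k", 2], key=sort_batch_id)
#   -> [1, 2, "2/countdown", "10/gsm8k"]


async def run_one_step(scheduler: Scheduler, rollout_tasks, eval_tasks, step: int):
    # Rollout tasks are keyed by the bare step number ...
    scheduler.schedule(rollout_tasks, batch_id=step)
    # ... while eval tasks use the "<step>/<taskset name>" convention.
    scheduler.schedule(eval_tasks, batch_id=f"{step}/countdown")

    # Wait for every rollout task of this step (min_num=None waits for all),
    # then aggregate the per-task metrics into mean/max/min entries.
    statuses = await scheduler.get_results(batch_id=step)
    succeeded = [s for s in statuses if s.ok]
    return gather_metrics([s.metric for s in succeeded], prefix="rollout")
```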
@CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible) @@ -30,7 +30,7 @@ def set_eval_interval(**kwargs): @CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) -def set_eval_on_latest_checkpoint(**kwargs): +def set_bench_on_latest_checkpoint(**kwargs): st.checkbox("Eval on Latest Checkpoint", **kwargs) From b63ee11075ba972c15d28701aabb479623e5bc99 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 3 Jul 2025 20:30:23 +0800 Subject: [PATCH 07/20] fix trainer test --- tests/trainer/trainer_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 67e44152df..00e406776b 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -60,7 +60,7 @@ def test_trainer(self): self.config.buffer.explorer_input.eval_tasksets.append( get_unittest_dataset_config("copy_countdown", "test") ) - self.config.trainer.save_interval = 6 + self.config.trainer.save_interval = 4 self.config.check_and_update() self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2 @@ -84,17 +84,17 @@ def test_trainer(self): self.assertEqual(parser.metric_max_step(response_metrics[0]), 8) ray.shutdown(_exiting_interpreter=True) # check checkpoint - checkpoint_step_6, _ = get_checkpoint_dir_with_step_num( + checkpoint_step_4, _ = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, - step_num=6, + step_num=4, ) # check save lastest checkpoint checkpoint_step_8, step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, ) - self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_6, "actor"))) > 0) + self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))) > 0) self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0) self.assertEqual(step_num, 8) # TODO: Reinit will fail when using v1 engine, find a way to fix it From 553b8619755ef6aa8333543b6cb4ed1639a7872a Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 3 Jul 2025 21:00:31 +0800 Subject: [PATCH 08/20] buffer writer support async --- tests/buffer/file_test.py | 16 +++++++++++----- tests/buffer/queue_test.py | 12 ++++++------ tests/buffer/sql_test.py | 12 ++++++------ tests/tools.py | 10 ++++++++++ trinity/buffer/buffer_writer.py | 8 ++++++-- trinity/buffer/queue.py | 2 +- trinity/buffer/writer/file_writer.py | 14 ++++++++++---- trinity/buffer/writer/queue_writer.py | 11 +++++++---- trinity/buffer/writer/sql_writer.py | 14 ++++++++++---- trinity/explorer/explorer.py | 6 ++++-- 10 files changed, 71 insertions(+), 34 deletions(-) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 4612d2ae7f..6a4248e2a1 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -16,7 +16,7 @@ from trinity.common.constants import StorageType -class TestFileBuffer(unittest.TestCase): +class TestFileBuffer(unittest.IsolatedAsyncioTestCase): temp_output_path = "tmp/test_file_buffer/" @classmethod @@ -93,25 +93,31 @@ def test_file_reader(self): break self.assertEqual(len(tasks), 16 * 3 - 20) - def test_file_writer(self): + async def test_file_writer(self): writer = get_buffer_writer( self.config.buffer.trainer_input.experience_buffer, self.config.buffer ) - writer.acquire() + await writer.acquire() writer.write( [ {"prompt": 
"hello world"}, {"prompt": "hi"}, ] ) - writer.release() + await writer.write_async( + [ + {"prompt": "My name is"}, + {"prompt": "What is your name?"}, + ] + ) + await writer.release() file_wrapper = ray.get_actor("json-test_buffer") self.assertIsNotNone(file_wrapper) file_path = default_storage_path( self.config.buffer.trainer_input.experience_buffer, self.config.buffer ) with open(file_path, "r") as f: - self.assertEqual(len(f.readlines()), 2) + self.assertEqual(len(f.readlines()), 4) def setUp(self): self.config = get_template_config() diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 5819aeb462..8e88a69653 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -3,7 +3,7 @@ import torch -from tests.tools import RayUnittestBase +from tests.tools import RayUnittestBaseAysnc from trinity.buffer.reader.queue_reader import QueueReader from trinity.buffer.writer.queue_writer import QueueWriter from trinity.common.config import BufferConfig, StorageConfig @@ -13,8 +13,8 @@ BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl") -class TestQueueBuffer(RayUnittestBase): - def test_queue_buffer(self): +class TestQueueBuffer(RayUnittestBaseAysnc): + async def test_queue_buffer(self): total_num = 8 put_batch_size = 2 read_batch_size = 4 @@ -32,7 +32,7 @@ def test_queue_buffer(self): ) writer = QueueWriter(meta, config) reader = QueueReader(meta, config) - self.assertEqual(writer.acquire(), 1) + self.assertEqual(await writer.acquire(), 1) exps = [ Experience( tokens=torch.tensor([float(j) for j in range(i + 1)]), @@ -43,7 +43,7 @@ def test_queue_buffer(self): for i in range(1, put_batch_size + 1) ] for _ in range(total_num // put_batch_size): - writer.write(exps) + await writer.write_async(exps) for _ in range(total_num // read_batch_size): exps = reader.read() self.assertEqual(len(exps), read_batch_size) @@ -62,7 +62,7 @@ def test_queue_buffer(self): ) exps = reader.read(batch_size=put_batch_size * 2) self.assertEqual(len(exps), put_batch_size * 2) - self.assertEqual(writer.release(), 0) + self.assertEqual(await writer.release(), 0) self.assertRaises(StopIteration, reader.read) with open(BUFFER_FILE_PATH, "r") as f: self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index e40a91b4c7..22b1c739a6 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -1,9 +1,9 @@ import os -import unittest import ray import torch +from tests.tools import RayUnittestBaseAysnc from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig @@ -13,8 +13,8 @@ db_path = os.path.join(os.path.dirname(__file__), "test.db") -class TestSQLBuffer(unittest.TestCase): - def test_create_sql_buffer(self) -> None: +class TestSQLBuffer(RayUnittestBaseAysnc): + async def test_create_sql_buffer(self) -> None: total_num = 8 put_batch_size = 2 read_batch_size = 4 @@ -42,9 +42,9 @@ def test_create_sql_buffer(self) -> None: ) for i in range(1, put_batch_size + 1) ] - self.assertEqual(sql_writer.acquire(), 1) + self.assertEqual(await sql_writer.acquire(), 1) for _ in range(total_num // put_batch_size): - sql_writer.write(exps) + await sql_writer.write_async(exps) for _ in range(total_num // read_batch_size): exps = sql_reader.read() self.assertEqual(len(exps), read_batch_size) @@ -66,5 +66,5 @@ def test_create_sql_buffer(self) -> None: 
self.assertEqual(len(exps), put_batch_size * 2) db_wrapper = ray.get_actor("sql-test_buffer") self.assertIsNotNone(db_wrapper) - self.assertEqual(sql_writer.release(), 0) + self.assertEqual(await sql_writer.release(), 0) self.assertRaises(StopIteration, sql_reader.read) diff --git a/tests/tools.py b/tests/tools.py index 7be4a2c4ef..a9b2ca8349 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -182,3 +182,13 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): ray.shutdown(_exiting_interpreter=True) + + +class RayUnittestBaseAysnc(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls): + ray.init(ignore_reinit_error=True, namespace="trinity_unittest") + + @classmethod + def tearDownClass(cls): + ray.shutdown(_exiting_interpreter=True) diff --git a/trinity/buffer/buffer_writer.py b/trinity/buffer/buffer_writer.py index 13079ffb76..3d3e939196 100644 --- a/trinity/buffer/buffer_writer.py +++ b/trinity/buffer/buffer_writer.py @@ -11,7 +11,11 @@ def write(self, data: List) -> None: """Write to buffer.""" @abstractmethod - def acquire(self) -> int: + async def write_async(self, data: List) -> None: + """Write to buffer asynchronously.""" + + @abstractmethod + async def acquire(self) -> int: """Acquire the buffer writer. Returns: @@ -19,7 +23,7 @@ def acquire(self) -> int: """ @abstractmethod - def release(self) -> int: + async def release(self) -> int: """Release the buffer writer. After release, the buffer writer can not be used again. Returns: diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index b49644e13c..ceddcfed91 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -57,7 +57,7 @@ async def release(self) -> int: self.ref_count -= 1 if self.ref_count <= 0: await self.queue.put(self.FINISH_MESSAGE) - self.writer.release() + await self.writer.release() return self.ref_count def length(self) -> int: diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index b2dff825de..93c10479ca 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -20,15 +20,21 @@ def write(self, data: List) -> None: else: self.writer.write(data) - def acquire(self) -> int: + async def write_async(self, data): if self.wrap_in_ray: - return ray.get(self.writer.acquire.remote()) + await self.writer.write.remote(data) + else: + self.writer.write(data) + + async def acquire(self) -> int: + if self.wrap_in_ray: + return await self.writer.acquire.remote() else: return 0 - def release(self) -> int: + async def release(self) -> int: if self.wrap_in_ray: - return ray.get(self.writer.release.remote()) + return await self.writer.release.remote() else: self.writer.release() return 0 diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 7b12fab4c1..9b13262b80 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -23,8 +23,11 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): def write(self, data: List) -> None: ray.get(self.queue.put_batch.remote(data)) - def acquire(self) -> int: - return ray.get(self.queue.acquire.remote()) + async def write_async(self, data): + return await self.queue.put_batch.remote(data) - def release(self) -> int: - return ray.get(self.queue.release.remote()) + async def acquire(self) -> int: + return await self.queue.acquire.remote() + + async def release(self) -> int: + return await self.queue.release.remote() diff --git a/trinity/buffer/writer/sql_writer.py 
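The writer changes in this patch make the buffer API awaitable: `acquire`, `release`, and the new `write_async` are coroutines, and the Ray-backed writers now await the `.remote()` calls directly instead of wrapping them in `ray.get`. A minimal producer-side sketch of the acquire/write/release lifecycle follows; it is illustrative only, assumes `get_buffer_writer` is exposed by `trinity.buffer` (the export path is an assumption), and takes the storage and buffer configs from an existing `Config` object.

```python
from trinity.buffer import get_buffer_writer  # assumed export path


async def write_experiences(config, experiences):
    """Sketch of the acquire -> write_async -> release lifecycle."""
    writer = get_buffer_writer(config.buffer.explorer_output, config.buffer)

    # acquire() registers this producer; the queue/SQL backends return a
    # reference count, purely local writers return 0.
    await writer.acquire()
    try:
        # Asynchronous write; the synchronous write() is still available.
        await writer.write_async(experiences)
    finally:
        # release() de-registers the producer; once the last writer releases,
        # the buffer is closed (e.g. the queue sends its finish message to readers).
        await writer.release()
```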
b/trinity/buffer/writer/sql_writer.py index 95344d4447..a951201b80 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -23,15 +23,21 @@ def write(self, data: list) -> None: else: self.db_wrapper.write(data) - def acquire(self) -> int: + async def write_async(self, data): if self.wrap_in_ray: - return ray.get(self.db_wrapper.acquire.remote()) + await self.db_wrapper.write.remote(data) + else: + self.db_wrapper.write(data) + + async def acquire(self) -> int: + if self.wrap_in_ray: + return await self.db_wrapper.acquire.remote() else: return 0 - def release(self) -> int: + async def release(self) -> int: if self.wrap_in_ray: - return ray.get(self.db_wrapper.release.remote()) + return await self.db_wrapper.release.remote() else: self.db_wrapper.release() return 0 diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index ed0c8fe886..d954344aff 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -43,12 +43,12 @@ def __init__(self, config: Config): self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) + self.experience_buffer = None if self.config.mode != "bench": self.experience_buffer = get_buffer_writer( self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) - self.experience_buffer.acquire() self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0) self.taskset = get_buffer_reader( self.config.buffer.explorer_input.taskset, self.config.buffer @@ -169,6 +169,8 @@ async def prepare(self) -> None: asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) ) asyncio.gather(*futures, return_exceptions=True) + if self.experience_buffer: + await self.experience_buffer.acquire() if self.config.explorer.eval_on_startup and self.explore_step_num == 0: self.eval() @@ -217,7 +219,7 @@ async def explore_step(self) -> bool: self.logger.warning("No more tasks to explore. 
             self.logger.warning("No more tasks to explore. Stop exploring.")
             await self.save_checkpoint(sync_weight=False)
             self.status = RunningStatus.STOPPED
-            self.experience_buffer.release()
+            await self.experience_buffer.release()
             return False
         self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1)
         self.explore_step_num += 1

From 87d80d6a9dae9ff9412c3c97f2e469cbfd04dac9 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 3 Jul 2025 21:03:04 +0800
Subject: [PATCH 09/20] fix tests

---
 tests/tools.py                | 2 +-
 tests/trainer/trainer_test.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/tools.py b/tests/tools.py
index a9b2ca8349..94127a462e 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -163,7 +163,7 @@ def metric_max_step(self, metric_name: str) -> int:
     def metric_steps(self, metric_name: str) -> List[int]:
         if not self.metric_exist(metric_name):
             raise ValueError(f"Metric '{metric_name}' does not exist.")
-        return list(self._metrics[metric_name].keys())
+        return list(sorted(self._metrics[metric_name].keys()))
 
     def metric_values(self, metric_name: str) -> List:
         if not self.metric_exist(metric_name):
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 00e406776b..2a8248f527 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -117,8 +117,7 @@ def test_trainer(self):
 
     def tearDown(self):
         # remove dir only when the test passed
-        # shutil.rmtree(self.config.checkpoint_job_dir)
-        pass
+        shutil.rmtree(self.config.checkpoint_job_dir)
 
 
 class TestStepAheadAsyncRL(BaseTrainerCase):

From 7a12516f5624673de4dc6ec9300e066ff3e68ca1 Mon Sep 17 00:00:00 2001
From: pxc
Date: Fri, 4 Jul 2025 11:01:10 +0800
Subject: [PATCH 10/20] fix tests

---
 tests/buffer/file_test.py     |  5 +++--
 tests/tools.py                |  2 +-
 tests/trainer/trainer_test.py | 19 ++++++++++---------
 trinity/cli/launcher.py       |  1 +
 trinity/explorer/explorer.py  | 17 ++++++++++-------
 5 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
index 6a4248e2a1..d6233eb51f 100644
--- a/tests/buffer/file_test.py
+++ b/tests/buffer/file_test.py
@@ -30,7 +30,7 @@ def tearDownClass(cls):
         if os.path.exists(cls.temp_output_path):
             os.system(f"rm -rf {cls.temp_output_path}")
 
-    def test_file_buffer(self):
+    async def test_file_buffer(self):
         meta = StorageConfig(
             name="test_buffer",
             path=os.path.join(self.temp_output_path, "buffer.jsonl"),
@@ -46,8 +46,9 @@ def test_file_buffer(self):
 
         # test writer
         writer = JSONWriter(meta, None)
+        await writer.acquire()
         writer.write(data)
-        writer.release()
+        await writer.release()
 
         # test reader
         meta.path = self.temp_output_path
diff --git a/tests/tools.py b/tests/tools.py
index 94127a462e..a9b2ca8349 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -163,7 +163,7 @@ def metric_max_step(self, metric_name: str) -> int:
     def metric_steps(self, metric_name: str) -> List[int]:
         if not self.metric_exist(metric_name):
             raise ValueError(f"Metric '{metric_name}' does not exist.")
-        return list(sorted(self._metrics[metric_name].keys()))
+        return list(self._metrics[metric_name].keys())
 
     def metric_values(self, metric_name: str) -> List:
         if not self.metric_exist(metric_name):
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 2a8248f527..170871fe81 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -106,14 +106,15 @@ def test_trainer(self):
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
-        countdown_metrics = parser.metric_list("eval/countdown")
-        copy_countdown_metrics = parser.metric_list("eval/copy_countdown")
-        self.assertTrue(len(countdown_metrics) > 0)
-        self.assertTrue(len(copy_countdown_metrics) > 0)
-        countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
-        countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
-        self.assertEqual([0, 4, 8], countdown_metric_steps)
-        self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
+        for prefix in ["eval", "bench"]:
+            countdown_metrics = parser.metric_list(f"{prefix}/countdown")
+            copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown")
+            self.assertTrue(len(countdown_metrics) > 0)
+            self.assertTrue(len(copy_countdown_metrics) > 0)
+            countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
+            countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
+            self.assertEqual([0, 4, 8], countdown_metric_steps)
+            self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
 
     def tearDown(self):
         # remove dir only when the test passed
@@ -352,7 +353,7 @@ def test_fully_async_mode(self):
         explorer_process_1 = multiprocessing.Process(target=run_explorer, args=(explorer1_config,))
         explorer_process_1.start()
-        time.sleep(20)
+        time.sleep(5)
         explorer2_config.explorer.name = "explorer2"
         explorer2_config.check_and_update()
         explorer_process_2 = multiprocessing.Process(target=run_explorer, args=(explorer2_config,))
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 1b3ba1f4bb..76830a125f 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -25,6 +25,7 @@ def bench(config: Config) -> None:
     """Evaluate model."""
+    config.explorer.name = "benchmark"
     explorer = (
         ray.remote(Explorer)
         .options(
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index d954344aff..2da5eb64bc 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -260,12 +260,13 @@ async def benchmark(self) -> bool:
         if self.config.explorer.bench_on_latest_checkpoint:
             self.explore_step_num = await self._checkpoint_weights_update()
             self.eval()
-            await self._log_eval_metrics()
+            await self._log_eval_metrics(prefix="bench")
             return True
 
         # benchmark on base model
-        self.eval()
-        await self._log_eval_metrics()
+        if self.config.explorer.eval_on_startup:
+            await self._log_eval_metrics(prefix="bench")
+
         # benchmark on all checkoints
         all_ckp_steps = sorted(
             [
@@ -278,7 +279,7 @@ async def benchmark(self) -> bool:
         for step_num in all_ckp_steps:
             self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num)
             self.eval()
-            await self._log_eval_metrics()
+            await self._log_eval_metrics(prefix="bench")
         return True
 
     async def save_checkpoint(self, sync_weight: bool = False) -> None:
@@ -323,7 +324,7 @@ async def _log_explore_metrics(self, step: int) -> None:
         metric = gather_metrics([status.metric for status in results], "rollout")
         self.monitor.log(metric, step=step)
 
-    async def _log_eval_metrics(self, step: Optional[int] = None) -> None:
+    async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eval") -> None:
         if not self.pending_eval_tasks:
             return
         step = step or self.explore_step_num
@@ -336,9 +337,11 @@ async def _log_eval_metrics(self, step: Optional[int] = None) -> None:
             self.pending_eval_tasks.popleft()
             eval_results = await self.scheduler.get_results(f"{step}/{eval_task_name}")
             metric.update(
-                gather_metrics([status.metric for status in eval_results], f"eval/{eval_task_name}")
+                gather_metrics(
+                    [status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
+                )
             )
-        metric["eval/total_time"] = time.time() - st
+        metric[f"{prefix}/total_time"] = time.time() - st
         self.monitor.log(metric, step)
 
     async def running_status(self) -> RunningStatus:

From 120831f4d9630a8260b5778ba04fb2618aeeeba1 Mon Sep 17 00:00:00 2001
From: pxc
Date: Fri, 4 Jul 2025 11:28:57 +0800
Subject: [PATCH 11/20] overlay log and weight sync

---
 trinity/explorer/explorer.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 2da5eb64bc..cadb00913b 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -287,7 +287,9 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
         self.logger.info("Waiting for all tasks to complete")
         await self.scheduler.wait_all()
         self.logger.info(f"All tasks before step {self.explore_step_num} have completed.")
-        await self._log_metrics(self.last_sync_step + 1, self.explore_step_num)
+        log_task = asyncio.create_task(
+            self._log_metrics(self.last_sync_step + 1, self.explore_step_num)
+        )
 
         if sync_weight:
             # sync weights
@@ -301,6 +303,9 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
             self.last_sync_step = self.explore_step_num
             self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} finished")
 
+        # overlay log and weight sync
+        await log_task
+
         # save explore checkpoint
         self.cache.save_explorer(
             current_step=self.explore_step_num,

From e41f53dbd376ad6ed76e9c160d016b9f0f74bffb Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 9 Jul 2025 11:34:55 +0800
Subject: [PATCH 12/20] fix explorer

---
 trinity/explorer/explorer.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 33ef7e30b0..c7cca6a67b 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -173,6 +173,7 @@ async def _nccl_weights_update(self):
         await asyncio.gather(
             *[model.sync_model.remote(self.explore_step_num) for model in self.models]
         )
+        self.status = RunningStatus.RUNNING
 
     async def ready_to_sync(self):
         async with self._ready_to_sync_condition:
@@ -313,12 +314,10 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
         if sync_weight:
             # sync weights
             self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.")
-            self.status = RunningStatus.WAITING_SYNC
             if self.use_checkpoint_weights_update:
                 await self._checkpoint_weights_update()
             else:  # nccl weights update
                 await self._nccl_weights_update()
-            self.status = RunningStatus.RUNNING
             self.last_sync_step = self.explore_step_num
             self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} finished")
 
@@ -328,16 +327,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
     async def sync_weight(self) -> None:
         """Synchronize model weights."""
         # call this method before training start to load the latest model weights
-        self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.")
-        if self.use_checkpoint_weights_update:
-            await self._checkpoint_weights_update()
-        else:  # nccl weights update
-            await self._nccl_weights_update()
-        # save explore checkpoint
-        self.cache.save_explorer(
-            current_step=self.explore_step_num,
-            current_task_index=self.explore_step_num * self.config.buffer.batch_size,
-        )
+        await self.save_checkpoint(sync_weight=True)
 
     async def _log_metrics(self, start_step: int, end_step: int) -> None:
         for step in range(start_step, end_step + 1):
From f9c95699549a33bfa1fc45588bdf1d093070f8f1 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 9 Jul 2025 12:00:24 +0800
Subject: [PATCH 13/20] fix tests

---
 trinity/explorer/explorer.py | 12 +++++++-----
 trinity/trainer/trainer.py   |  4 ----
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index c7cca6a67b..c22516be42 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -287,7 +287,7 @@ async def benchmark(self) -> bool:
         if self.config.explorer.eval_on_startup:
             await self._log_eval_metrics(prefix="bench")
 
-        # benchmark on all checkoints
+        # benchmark on all checkpoints
         all_ckp_steps = sorted(
             [
                 int(ckp.split("global_step_")[-1])
@@ -324,6 +324,12 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
         # overlay log and weight sync
         await log_task
 
+        # save explore checkpoint
+        self.cache.save_explorer(
+            current_step=self.explore_step_num,
+            current_task_index=self.explore_step_num * self.config.buffer.batch_size,
+        )
+
     async def sync_weight(self) -> None:
         """Synchronize model weights."""
         # call this method before training start to load the latest model weights
@@ -364,10 +370,6 @@ async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eva
     async def running_status(self) -> RunningStatus:
         return self.status
 
-    def flush_log(self, step: int) -> None:
-        """Flush the log of the current step."""
-        self.monitor.log({}, step=step, commit=True)
-
     async def shutdown(self) -> None:
         self.monitor.close()
         await self.scheduler.stop()
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 3a9f51f677..1378449cf2 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -73,10 +73,6 @@ def sync_weight(self) -> None:
             f"Trainer synchronizing weights at step {self.engine.train_step_num} end."
         )
 
-    def flush_log(self, step: int) -> None:
-        """Flush the log of the current step."""
-        self.engine.monitor.log({}, step=step, commit=True)
-
     def shutdown(self) -> None:
         # if checkpoint not saved, save the last checkpoint
         step_num = self.engine.train_step_num

From e9b1b78e40919dfa86951483d4e128bea6c225c2 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 9 Jul 2025 14:04:38 +0800
Subject: [PATCH 14/20] fix env vars

---
 trinity/common/config.py      | 2 ++
 trinity/explorer/scheduler.py | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/trinity/common/config.py b/trinity/common/config.py
index a88512d191..9ac865e436 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -310,6 +310,8 @@ class ExplorerConfig:
     runner_per_model: int = 8  # number of runners per each rollout model
     max_timeout: int = 900  # wait each task for 15 minutes
    max_retry_times: int = 2  # retry each task for 2 times if it fails or timeout
+    env_vars: dict = field(default_factory=dict)  # environment variables for workflow runner
+    runner_num: Optional[int] = None  # deprecated
 
     # for inference models
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index 1f5d07706b..39274cd988 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -42,6 +42,9 @@ def _create_runner(self):
             .options(
                 namespace=self.namespace,
                 scheduling_strategy="SPREAD",
+                runtime_env={
+                    "env_vars": self.config.explorer.env_vars,
+                },
             )
             .remote(self.config, self.rollout_model, self.auxiliary_models)
         )

From fb7ea23b94919a9075b503569c3be200791351cd Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 9 Jul 2025 14:13:38 +0800
Subject: [PATCH 15/20] add env vars examples

---
 examples/grpo_alfworld/alfworld.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index 281008ae46..0f1146b4e9 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -44,6 +44,8 @@ explorer:
     seed: 42
     gpu_memory_utilization: 0.7
     enable_chunked_prefill: true
+  env_vars:
+    TMPDIR: /PATH/TO/ALFWORLD_TMP_DIR
 synchronizer:
   sync_method: 'nccl'
   sync_interval: 8

From d513b81403243c6988dd607e0808bcb4925bd5a1 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 10 Jul 2025 10:11:09 +0800
Subject: [PATCH 16/20] test queue capacity

---
 tests/buffer/queue_test.py | 29 +++++++++++++++++++++++++++++
 trinity/buffer/queue.py    |  2 +-
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 8e88a69653..32b59534a7 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -1,4 +1,5 @@
 import os
+import threading
 import time
 
 import torch
@@ -71,6 +72,34 @@ async def test_queue_buffer(self):
         et = time.time()
         self.assertTrue(et - st > 2)
 
+        # test queue capacity
+        meta = StorageConfig(
+            name="test_buffer_small",
+            algorithm_type="ppo",
+            storage_type=StorageType.QUEUE,
+            max_read_timeout=3,
+            capacity=4,
+            path=BUFFER_FILE_PATH,
+        )
+        writer = QueueWriter(meta, config)
+        reader = QueueReader(meta, config)
+        writer.write([{"content": "hello"}])
+        writer.write([{"content": "hi"}])
+        writer.write([{"content": "hello"}])
+        writer.write([{"content": "hi"}])
+
+        # should be blocked
+        def write_blocking_call():
+            writer.write([{"content": "blocked"}])
+
+        thread = threading.Thread(target=write_blocking_call)
+        thread.start()
+        thread.join(timeout=2)
+        self.assertTrue(thread.is_alive(), "write() did not block as expected")
+        reader.read()
+        thread.join(timeout=1)
+        self.assertFalse(thread.is_alive())
+
     def setUp(self):
         if os.path.exists(BUFFER_FILE_PATH):
             os.remove(BUFFER_FILE_PATH)
diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py
index ceddcfed91..534283c50b 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/queue.py
@@ -28,7 +28,7 @@ class QueueActor:
     def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
         self.logger = get_logger(__name__)
         self.config = config
-        self.capacity = getattr(config, "capacity", 10000)
+        self.capacity = storage_config.capacity
        self.queue = asyncio.Queue(self.capacity)
         st_config = deepcopy(storage_config)
         st_config.wrap_in_ray = False

From 777bb71dc9fbc94e2f823743c4de284347db2d7f Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 10 Jul 2025 10:38:30 +0800
Subject: [PATCH 17/20] fix scheduler default timeout

---
 trinity/explorer/scheduler.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index 39274cd988..4c8015370a 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -121,7 +121,7 @@ def __init__(
         self.rollout_model = rollout_model
         self.auxiliary_models = auxiliary_models or []
         self.namespace = ray.get_runtime_context().namespace
-        self.timeout = config.explorer.max_timeout
+        self.default_timeout = config.explorer.max_timeout * (config.explorer.max_retry_times + 1)
         self.max_retry_times = config.explorer.max_retry_times
         self.running = False
 
@@ -298,7 +298,7 @@ async def get_results(
             timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout.
             clear_timeout_tasks (`bool`): Whether to clear timeout tasks.
         """
-        timeout = timeout or self.timeout
+        timeout = timeout or self.default_timeout
         start_time = time.time()
         if min_num is None:
             min_num = 0
@@ -357,7 +357,7 @@ async def wait_all(
             timeout (`float`): timeout in seconds.
             clear_timeout_tasks (`bool`): Whether to clear timeout tasks.
         """
-        timeout = timeout or self.timeout
+        timeout = timeout or self.default_timeout
         start_time = time.time()
 
         self.logger.debug("Waiting for all tasks to complete...")

From 63b77a040367d27b443b8542af07cf4ce85ac82b Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 10 Jul 2025 10:41:32 +0800
Subject: [PATCH 18/20] change default timeout

---
 trinity/common/config.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9ac865e436..70d147460f 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -308,7 +308,7 @@ class ExplorerConfig:
     # For sync engine (vllm), it should be `1`.
     # For async engine (vllm_async), it could be a large number.
     runner_per_model: int = 8  # number of runners per each rollout model
-    max_timeout: int = 900  # wait each task for 15 minutes
+    max_timeout: int = 1800  # wait each task for 30 minutes
     max_retry_times: int = 2  # retry each task for 2 times if it fails or timeout
     env_vars: dict = field(default_factory=dict)  # environment variables for workflow runner
@@ -368,7 +368,7 @@ class SynchronizerConfig:
     # allow explorer to run `sync_offset` steps before sync
     sync_offset: int = 0
     # waiting for `sync_timeout` seconds before timeout in `nccl` method
-    sync_timeout: int = 1800
+    sync_timeout: int = 3600
     # wait for the lastest checkpoint to be ready
     # TODO: to be used
     wait_for_checkpoint: bool = False

From a8fe49f830d8884f7770da2ee0c7d47f9d6c21a0 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 10 Jul 2025 11:44:32 +0800
Subject: [PATCH 19/20] optimize wait_all timeout logic

---
 examples/grpo_alfworld/alfworld.yaml |  5 +++--
 tests/explorer/scheduler_test.py     | 24 ++++++++++++++++++++++++
 trinity/explorer/scheduler.py        |  9 +++++++--
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index 0f1146b4e9..f6079ad55e 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -34,6 +34,7 @@ buffer:
     path: 'sqlite:///alfworld.db'
 explorer:
   runner_num: 32
+  max_timeout: 3600
   rollout_model:
     engine_type: vllm_async
     engine_num: 2
@@ -48,8 +49,8 @@ explorer:
     TMPDIR: /PATH/TO/ALFWORLD_TMP_DIR
 synchronizer:
   sync_method: 'nccl'
-  sync_interval: 8
-  sync_timeout: 1200
+  sync_interval: 5
+  sync_timeout: 3600
 trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml'
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index e9323b50ed..0b96c98805 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -271,6 +271,30 @@ async def test_wait_all(self):
         self.assertLess(end_time - start_time, 1.0)
         await scheduler.stop()
 
+    async def test_wait_all_timeout_with_multi_batch(self):
+        self.config.explorer.max_timeout = 5
+        self.config.explorer.rollout_model.engine_num = 4
+        self.config.explorer.runner_per_model = 1
+
+        scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
+        await scheduler.start()
+
+        tasks = generate_tasks(1, timeout_num=3, timeout_seconds=3)
+        scheduler.schedule(tasks, batch_id=0)
+        tasks = generate_tasks(2, timeout_num=2, timeout_seconds=3)
+        scheduler.schedule(tasks, batch_id=1)
+        tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3)
+        scheduler.schedule(tasks, batch_id=2)
+        start_time = time.time()
+        await scheduler.wait_all()
+        end_time = time.time()
+        self.assertTrue(
+            end_time - start_time > 9,
+            f"wait time should be greater than 9, but got {end_time - start_time}",
+        )
+
+        await scheduler.stop()
+
     async def test_concurrent_operations(self):
         scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
         await scheduler.start()
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index 4c8015370a..4203826d87 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -354,14 +354,14 @@ async def wait_all(
         """Wait for all tasks to complete without poping results. If timeout reached, raise TimeoutError.
 
         Args:
-            timeout (`float`): timeout in seconds.
+            timeout (`float`): timeout in seconds. Raise `TimeoutError` when no new tasks is completed within timeout.
             clear_timeout_tasks (`bool`): Whether to clear timeout tasks.
         """
         timeout = timeout or self.default_timeout
         start_time = time.time()
 
         self.logger.debug("Waiting for all tasks to complete...")
-
+        last_completed_count = 0
         while time.time() - start_time < timeout:
             has_pending = bool(self.pending_tasks)
             has_running = bool(self.running_tasks)
@@ -372,6 +372,11 @@ async def wait_all(
 
             pending_count = sum(len(tasks) for tasks in self.pending_tasks.values())
             running_count = sum(len(futures) for futures in self.running_tasks.values())
+            completed_count = sum(len(tasks) for tasks in self.completed_tasks.values())
+            if completed_count != last_completed_count:
+                # flush timeout when new tasks are completed
+                start_time = time.time()
+                last_completed_count = completed_count
             self.logger.debug(f"Pending tasks: {pending_count}, Running tasks: {running_count}")
 
             await asyncio.sleep(0.1)

From db01273fa8a18ab5215c42739c846a3ffcc7bd4e Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 10 Jul 2025 11:56:25 +0800
Subject: [PATCH 20/20] clean code

---
 trinity/explorer/scheduler.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index 4203826d87..cc3bea2ca1 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -370,15 +370,12 @@ async def wait_all(
                 self.logger.debug("All tasks completed successfully")
                 return
 
-            pending_count = sum(len(tasks) for tasks in self.pending_tasks.values())
-            running_count = sum(len(futures) for futures in self.running_tasks.values())
             completed_count = sum(len(tasks) for tasks in self.completed_tasks.values())
             if completed_count != last_completed_count:
                 # flush timeout when new tasks are completed
                 start_time = time.time()
                 last_completed_count = completed_count
-            self.logger.debug(f"Pending tasks: {pending_count}, Running tasks: {running_count}")
 
             await asyncio.sleep(0.1)
 
         pending_count = sum(len(tasks) for tasks in self.pending_tasks.values())