| 1 | +"""Scheduler for rollout tasks.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import time |
| 5 | +from typing import List, Dict, Tuple, Optional |
| 6 | +from collections import defaultdict, deque |
| 7 | +import traceback |
| 8 | + |
| 9 | +import ray |
| 10 | + |
| 11 | +from trinity.common.models import InferenceModel |
| 12 | +from trinity.common.config import Config |
| 13 | +from trinity.common.workflows import Task |
| 14 | +from trinity.explorer.workflow_runner import WorkflowRunner, Status |
| 15 | +from trinity.utils.log import get_logger |


class RunnerWrapper:
    """Wraps a `WorkflowRunner` actor and tracks whether it is busy."""

    def __init__(self, runner: WorkflowRunner, runner_id: int):
        self.logger = get_logger(__name__)
        self.runner = runner
        self.runner_id = runner_id
        self.is_busy = False
        self.current_task: Optional[Task] = None

    async def run_with_retry(self, task: Task, retry_times: int) -> Tuple[Status, int]:
        """Run a task, retrying up to `retry_times` extra times on failure.

        Returns:
            `Status`: The return status of the task.
            `int`: The runner_id of the current runner.
        """
        last_exception_msg = None
        self.is_busy = True
        self.current_task = task
        start_time = time.time()
        try:
            for attempt in range(retry_times + 1):
                try:
                    status = await self.runner.run.remote(task)
                    if status.ok:
                        break
                    else:
                        self.logger.error(status.message)
                except Exception:
                    last_exception_msg = traceback.format_exc()
                    self.logger.warning(
                        f"Task execution attempt {attempt + 1} failed:\n{last_exception_msg}"
                    )
                    status = Status(ok=False, metric=dict(), message=last_exception_msg)
        finally:
            end_time = time.time()
            status.metric["task_run_time"] = end_time - start_time
            self.is_busy = False
            self.current_task = None
        return status, self.runner_id


class Scheduler:
    """Scheduler for rollout tasks."""

    def __init__(
        self,
        config: Config,
        rollout_model: List[InferenceModel],
        auxiliary_models: Optional[List[List[InferenceModel]]] = None,
    ):
        self.logger = get_logger(__name__)
        self.config = config
        self.rollout_model = rollout_model
        self.auxiliary_models = auxiliary_models or []
        self.namespace = ray.get_runtime_context().namespace
        self.timeout = config.explorer.max_timeout
        self.max_retry_times = config.explorer.max_retry_times

        self.runner_num = len(rollout_model) * config.explorer.runner_per_model
        self.runners: Dict[int, RunnerWrapper] = dict()
        self.idle_runners = set()
        self.busy_runners = dict()

        self.pending_tasks: Dict[int, deque] = defaultdict(deque)  # step -> tasks
        self.running_tasks: Dict[int, set[asyncio.Future]] = defaultdict(set)  # step -> futures
        self.completed_tasks: Dict[int, deque[Status]] = defaultdict(deque)  # step -> results

        self.scheduler_task: Optional[asyncio.Task] = None
        self.running = False

        self.total_scheduled = 0
        self.total_completed = 0
        # Runner actors are created in `start()`, which awaits `_create_runner`.

    async def _create_runner(
        self,
        runner_id: int,
    ) -> None:
        """Create a `WorkflowRunner` actor for `runner_id` and mark it idle."""
        runner = RunnerWrapper(
            runner=(
                ray.remote(WorkflowRunner)
                .options(
                    namespace=self.namespace,
                    scheduling_strategy="SPREAD",
                )
                .remote(
                    self.config,
                    self.rollout_model[runner_id % len(self.rollout_model)],
                    [
                        self.auxiliary_models[j][runner_id % len(self.auxiliary_models[j])]
                        for j in range(len(self.auxiliary_models))
                    ],
                )
            ),
            runner_id=runner_id,
        )
        self.runners[runner_id] = runner
        self.idle_runners.add(runner_id)

    async def _restart_runner(self, runner_id: int) -> None:
        """Restart a runner by killing its actor and creating a new one."""
        try:
            ray.kill(self.runners[runner_id].runner)
        except Exception:
            # The actor may already be dead; ignore and recreate it.
            pass

        await self._create_runner(runner_id)

    async def _scheduler_loop(self) -> None:
        self.logger.info("Scheduler loop started.")
        while self.running:
            try:
                await self._schedule_pending_tasks()
                await self._check_completed_tasks()
                await asyncio.sleep(0.01)
            except Exception:
                self.logger.error(f"Error in scheduler loop:\n{traceback.format_exc()}")
                await asyncio.sleep(0.1)
        self.logger.info("Scheduler loop stopped.")

    async def _schedule_pending_tasks(self) -> None:
        if not self.idle_runners:
            return

        for step in sorted(self.pending_tasks.keys()):
            task_queue = self.pending_tasks[step]

            while task_queue and self.idle_runners:
                task = task_queue.pop()
                runner_id = self.idle_runners.pop()
                self.busy_runners[runner_id] = (task, step)
                self.running_tasks[step].add(
                    asyncio.create_task(
                        self.runners[runner_id].run_with_retry(task, self.max_retry_times)
                    )
                )

            if not task_queue:
                del self.pending_tasks[step]

    async def _check_completed_tasks(self) -> None:
        for step in list(self.running_tasks.keys()):
            futures = self.running_tasks[step]

            for future in list(futures):
                if future.done():
                    futures.remove(future)
                    try:
                        task_result, runner_id = await future
                        self.completed_tasks[step].appendleft(task_result)
                        self.busy_runners.pop(runner_id)
                        self.idle_runners.add(runner_id)

                        self.logger.debug(
                            f"Task completed (step {step}), success: {task_result.ok}"
                        )

                    except Exception as e:
                        self.logger.error(f"Error getting task result: {e}")

            if not futures:
                del self.running_tasks[step]

    async def start(self) -> None:
        if self.running:
            return
        self.running = True
        await asyncio.gather(*[self._create_runner(i) for i in range(self.runner_num)])
        self.scheduler_task = asyncio.create_task(self._scheduler_loop())

    async def stop(self) -> None:
        if not self.running:
            return

        self.running = False
        all_running_futures = []
        for futures in self.running_tasks.values():
            all_running_futures.extend(futures)

        if all_running_futures:
            self.logger.info(
                f"Waiting for {len(all_running_futures)} running tasks to complete..."
            )
            await asyncio.gather(*all_running_futures, return_exceptions=True)

        if self.scheduler_task:
            self.scheduler_task.cancel()
            try:
                await self.scheduler_task
            except asyncio.CancelledError:
                pass

        self.logger.info("Scheduler stopped")

    def schedule(self, tasks: List[Task], step: int) -> None:
        """Schedule the provided tasks.

        Args:
            tasks (`List[Task]`): The tasks to schedule.
            step (`int`): The step number of the provided tasks.
        """
        if not tasks:
            return
        for task in tasks:
            self.pending_tasks[step].appendleft(task)

    async def get_results(
        self, step: int, min_num: Optional[int] = None, timeout: Optional[float] = None
    ) -> List[Status]:
        """Get the results of tasks at the specified step.

        Args:
            step (`int`): Only wait for tasks at this step.
            min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `step`.
            timeout (`float`): The timeout in seconds. If `None`, use the scheduler's default timeout.

        Returns:
            `List[Status]`: The statuses of the completed tasks.
        """
        timeout = timeout or self.timeout
        start_time = time.time()
        if min_num is None:
            min_num = (
                len(self.pending_tasks[step])
                + len(self.running_tasks[step])
                + len(self.completed_tasks[step])
            )
        self.logger.debug(f"Waiting for {min_num} tasks to complete...")

        while time.time() - start_time < timeout:
            completed_count = len(self.completed_tasks[step])
            if completed_count >= min_num:
                break
            await asyncio.sleep(0.1)

        results = []
        for _ in range(min_num):
            if len(self.completed_tasks[step]) > 0:
                results.append(self.completed_tasks[step].pop())

        if not self.completed_tasks[step]:
            del self.completed_tasks[step]

        completed_count = len(results)
        if completed_count < min_num:
            self.logger.warning(
                f"Timeout reached, only {completed_count}/{min_num} tasks completed"
            )

        return results
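

# ---------------------------------------------------------------------------
# Example usage (a hypothetical sketch, not part of the scheduler module
# itself): one way a caller might drive a single rollout step. It assumes a
# populated `Config` and a list of live `InferenceModel` handles are already
# available; `run_rollout_step` is an illustrative name, not an existing API.
# ---------------------------------------------------------------------------
async def run_rollout_step(
    config: Config, models: List[InferenceModel], tasks: List[Task]
) -> List[Status]:
    scheduler = Scheduler(config, rollout_model=models)
    await scheduler.start()  # creates the runner actors and the scheduler loop
    scheduler.schedule(tasks, step=0)
    # min_num=None waits for every task scheduled at step 0 (or until timeout).
    results = await scheduler.get_results(step=0)
    await scheduler.stop()
    return results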