Skip to content

Commit 52adb78

Browse files
committed
add new scheduler
1 parent 531f38c commit 52adb78

File tree

2 files changed

+262
-1
lines changed

2 files changed

+262
-1
lines changed

trinity/common/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ class ExplorerConfig:
306306
runner_num: int = 1
307307
max_timeout: int = 900 # wait each task for 15 minutes
308308
max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout
309-
env_vars: dict = field(default_factory=dict)
309+
runner_per_model: int = 8
310310

311311
# for inference models
312312
# for rollout model

trinity/explorer/scheduler.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""Scheduler for rollout tasks."""
2+
3+
import asyncio
4+
import time
5+
from typing import List, Dict, Tuple, Optional
6+
from collections import defaultdict, deque
7+
import traceback
8+
9+
import ray
10+
11+
from trinity.common.models import InferenceModel
12+
from trinity.common.config import Config
13+
from trinity.common.workflows import Task
14+
from trinity.explorer.workflow_runner import WorkflowRunner, Status
15+
from trinity.utils.log import get_logger
16+
17+
18+
19+
class RunnerWrapper:
    """Pair a workflow runner handle with its id and busy/idle bookkeeping.

    The wrapped ``runner`` only needs to expose ``run.remote(task)`` returning
    an awaitable that resolves to a ``Status``.
    """

    def __init__(self, runner: WorkflowRunner, runner_id: int):
        """Store the runner handle and initialise the bookkeeping fields.

        Args:
            runner (`WorkflowRunner`): The runner handle used to execute tasks.
            runner_id (`int`): The identifier reported back with every result.
        """
        self.logger = get_logger(__name__)
        self.runner = runner
        self.runner_id = runner_id
        self.is_busy = False
        # The task currently being executed, or None while the runner is idle.
        self.current_task: Optional[Task] = None

    async def run_with_retry(self, task: Task, retry_times: int) -> Tuple[Status, int]:
        """Run ``task`` on the wrapped runner, retrying when it fails.

        The task is attempted at most ``retry_times + 1`` times; a negative
        ``retry_times`` is treated as ``0`` so the task always runs once. An
        attempt counts as failed when the returned status is not ``ok`` or when
        the remote call raises. The elapsed wall-clock time across all attempts
        is stored in ``status.metric["task_run_time"]``.

        Args:
            task (`Task`): The task to execute.
            retry_times (`int`): How many extra attempts are allowed after the
                first failure.

        Returns:
            `Status`: The return status of the task (the last attempt's status).
            `int`: The runner_id of current runner.
        """
        # Stays None only if a BaseException (e.g. cancellation) escapes before
        # any attempt produced a status.
        status: Optional[Status] = None
        self.is_busy = True
        self.current_task = task
        start_time = time.time()
        try:
            for attempt in range(max(retry_times, 0) + 1):
                try:
                    status = await self.runner.run.remote(task)
                    if status.ok:
                        break
                    self.logger.error(status.message)
                except Exception:
                    # format_exc() renders the exception currently being handled;
                    # format_exception() requires explicit arguments.
                    last_exception_msg = traceback.format_exc()
                    self.logger.warning(
                        f"Task execution attempt {attempt + 1} failed:\n{last_exception_msg}"
                    )
                    status = Status(ok=False, metric=dict(), message=last_exception_msg)
        finally:
            if status is not None:
                status.metric["task_run_time"] = time.time() - start_time
            # Always release the runner, even when the coroutine is cancelled.
            self.is_busy = False
            self.current_task = None
        return status, self.runner_id
58+
59+
60+
class Scheduler:
    """Scheduler for rollout tasks.

    Tasks are queued per step with `schedule`, dispatched to idle runners by a
    background asyncio loop that `start` launches, and collected per step with
    `get_results`. `stop` waits for in-flight tasks and shuts the loop down.
    """

    def __init__(
        self,
        config: Config,
        rollout_model: List[InferenceModel],
        auxiliary_models: Optional[List[List[InferenceModel]]] = None,
    ):
        """Initialise the scheduler state; runners are created by `start`.

        Args:
            config (`Config`): Global config; ``config.explorer`` supplies
                ``max_timeout``, ``max_retry_times`` and ``runner_per_model``.
            rollout_model (`List[InferenceModel]`): Rollout models; each one
                is shared by ``runner_per_model`` runners.
            auxiliary_models (`Optional[List[List[InferenceModel]]]`): Groups
                of auxiliary models; every runner gets one model per group.
        """
        self.logger = get_logger(__name__)
        self.config = config
        self.rollout_model = rollout_model
        self.auxiliary_models = auxiliary_models or []
        self.namespace = ray.get_runtime_context().namespace
        self.timeout = config.explorer.max_timeout
        self.max_retry_times = config.explorer.max_retry_times
        self.running = False

        self.runner_num = len(rollout_model) * config.explorer.runner_per_model
        self.runners: Dict[int, RunnerWrapper] = dict()
        self.idle_runners = set()
        self.busy_runners = dict()  # runner_id -> (task, step)
        # Lets a finished future be mapped back to its runner even when the
        # future raised and therefore carries no runner_id in its result.
        self._future_to_runner: Dict[asyncio.Future, int] = dict()

        self.pending_tasks: Dict[int, deque] = defaultdict(deque)  # step -> tasks
        self.running_tasks: Dict[int, set[asyncio.Future]] = defaultdict(set)  # step -> futures
        self.completed_tasks: Dict[int, deque[Status]] = defaultdict(deque)  # step -> results

        self.scheduler_task: Optional[asyncio.Task] = None

        self.total_scheduled = 0
        self.total_completed = 0
        # NOTE: runners are created in `start()`. `_create_runner` is a coroutine,
        # so calling it here without awaiting would never execute it.

    def _build_runner(self, runner_id: int) -> RunnerWrapper:
        """Create the remote runner for ``runner_id`` and register its wrapper."""
        actor = (
            ray.remote(WorkflowRunner)
            .options(
                namespace=self.namespace,
                scheduling_strategy="SPREAD",
            )
            .remote(
                self.config,
                # Round-robin runners over the available models.
                self.rollout_model[runner_id % len(self.rollout_model)],
                [models[runner_id % len(models)] for models in self.auxiliary_models],
            )
        )
        wrapper = RunnerWrapper(runner=actor, runner_id=runner_id)
        self.runners[runner_id] = wrapper
        return wrapper

    async def _create_runner(
        self,
        runner_id: int,
    ) -> None:
        """Create runner ``runner_id`` and mark it as idle."""
        self._build_runner(runner_id)
        self.idle_runners.add(runner_id)

    def _restart_runner(self, runner_id: int):
        """Restart a runner.

        The old remote runner is killed on a best-effort basis and a fresh one
        is registered under the same id. If the runner is currently busy it is
        not marked idle here; that happens once its in-flight task is collected.
        """
        previous = self.runners.get(runner_id)
        if previous is not None:
            try:
                # Kill the remote runner handle, not the local wrapper object.
                ray.kill(previous.runner)
            except Exception:
                self.logger.warning(
                    f"Failed to kill runner {runner_id}:\n{traceback.format_exc()}"
                )

        self._build_runner(runner_id)
        if runner_id not in self.busy_runners:
            self.idle_runners.add(runner_id)

    async def _scheduler_loop(self) -> None:
        """Dispatch pending tasks and collect finished ones until stopped."""
        self.logger.info("Scheduler loop started.")
        while self.running:
            try:
                await self._schedule_pending_tasks()
                await self._check_completed_tasks()
                await asyncio.sleep(0.01)
            except Exception:
                self.logger.error(f"Error in scheduler loop:\n{traceback.format_exc()}")
                # Back off a little so a persistent error does not spin the loop.
                await asyncio.sleep(0.1)
        self.logger.info("Scheduler loop stopped.")

    async def _schedule_pending_tasks(self) -> None:
        """Assign pending tasks to idle runners, lowest step first."""
        if not self.idle_runners:
            return

        # sorted() copies the keys, so entries can be deleted while iterating.
        for step in sorted(self.pending_tasks):
            task_queue = self.pending_tasks[step]

            while task_queue and self.idle_runners:
                # schedule() pushes on the left, so popping right is FIFO.
                task = task_queue.pop()
                runner_id = self.idle_runners.pop()
                self.busy_runners[runner_id] = (task, step)
                future = asyncio.create_task(
                    self.runners[runner_id].run_with_retry(task, self.max_retry_times)
                )
                self._future_to_runner[future] = runner_id
                self.running_tasks[step].add(future)
                self.total_scheduled += 1

            if not task_queue:
                del self.pending_tasks[step]

    async def _check_completed_tasks(self) -> None:
        """Collect finished futures, store their results and free their runners.

        A future that raised or was cancelled is recorded as a failed `Status`,
        so its runner still returns to the idle pool and `get_results` does not
        wait for a result that will never arrive.
        """
        for step in list(self.running_tasks):
            futures = self.running_tasks[step]

            for future in [f for f in futures if f.done()]:
                futures.remove(future)
                runner_id = self._future_to_runner.pop(future, None)
                try:
                    task_result, reported_id = future.result()
                    if runner_id is None:
                        runner_id = reported_id
                except asyncio.CancelledError:
                    # There is no await in this method, so this can only come
                    # from the collected future having been cancelled.
                    self.logger.error("Error getting task result: task was cancelled")
                    task_result = Status(ok=False, metric=dict(), message="Task was cancelled.")
                except Exception as e:
                    self.logger.error(f"Error getting task result: {e}")
                    task_result = Status(
                        ok=False, metric=dict(), message=traceback.format_exc()
                    )

                if runner_id is not None:
                    self.busy_runners.pop(runner_id, None)
                    self.idle_runners.add(runner_id)

                self.completed_tasks[step].appendleft(task_result)
                self.total_completed += 1
                self.logger.debug(
                    f"Task completed (step {step}), success: {task_result.ok}"
                )

            if not futures:
                del self.running_tasks[step]

    async def start(self) -> None:
        """Create the runners (if needed) and launch the scheduler loop."""
        if self.running:
            return
        self.running = True
        try:
            # Skip ids that already have a runner so that start() after stop()
            # does not create duplicates.
            await asyncio.gather(
                *(
                    self._create_runner(i)
                    for i in range(self.runner_num)
                    if i not in self.runners
                )
            )
        except Exception:
            # Leave the scheduler startable again if runner creation failed.
            self.running = False
            raise
        self.scheduler_task = asyncio.create_task(self._scheduler_loop())

    async def stop(self) -> None:
        """Stop the scheduler loop after waiting for in-flight tasks.

        Tasks still pending are left in the queue. Results of tasks that
        finish during shutdown remain available through `get_results`.
        """
        if not self.running:
            return

        self.running = False
        all_running_futures = []
        for futures in self.running_tasks.values():
            all_running_futures.extend(futures)

        if all_running_futures:
            self.logger.info(f"Waiting for {len(all_running_futures)} running tasks to complete...")
            await asyncio.gather(*all_running_futures, return_exceptions=True)

        if self.scheduler_task:
            self.scheduler_task.cancel()
            try:
                await self.scheduler_task
            except asyncio.CancelledError:
                pass
            self.scheduler_task = None

        # The loop may have exited before the last futures finished; collect
        # them once more so their results are kept and their runners freed.
        await self._check_completed_tasks()

        self.logger.info("Scheduler stopped")

    def schedule(self, tasks: List[Task], step: int) -> None:
        """Schedule the provided tasks.

        Args:
            tasks (`List[Task]`): The tasks to schedule.
            step (`int`): The step number of provided tasks.
        """
        if not tasks:
            return
        # extendleft() pushes items one by one on the left, preserving FIFO
        # order with respect to the right-side pop in the dispatcher.
        self.pending_tasks[step].extendleft(tasks)

    async def get_results(
        self, step: int, min_num: Optional[int] = None, timeout: Optional[float] = None
    ) -> List[Status]:
        """Get the result of tasks at the specific step.

        Args:
            step (`int`): Only wait for tasks at this step.
            min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `step`.
            timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout.
                A value of `0` returns whatever is already completed without waiting.

        Returns:
            `List[Status]`: At most `min_num` results, oldest first; fewer if the timeout was reached.
        """
        if timeout is None:
            timeout = self.timeout
        if min_num is None:
            # .get() keeps these reads from inserting empty entries into the defaultdicts.
            min_num = (
                len(self.pending_tasks.get(step, ()))
                + len(self.running_tasks.get(step, ()))
                + len(self.completed_tasks.get(step, ()))
            )
        self.logger.debug(f"Waiting for {min_num} tasks to complete...")

        # The monotonic clock is immune to wall-clock adjustments while waiting.
        deadline = time.monotonic() + timeout
        while (
            len(self.completed_tasks.get(step, ())) < min_num
            and time.monotonic() < deadline
        ):
            await asyncio.sleep(0.1)

        results = []
        finished = self.completed_tasks.get(step)
        if finished is not None:
            # Results are pushed on the left, so popping right yields oldest first.
            while finished and len(results) < min_num:
                results.append(finished.pop())
            if not finished:
                del self.completed_tasks[step]

        completed_count = len(results)
        if completed_count < min_num:
            self.logger.warning(
                f"Timeout reached, only {completed_count}/{min_num} tasks completed"
            )

        return results

0 commit comments

Comments
 (0)