Skip to content

Commit e322d3c

Browse files
author
chucai.dzq
committed
impl rollout controller
1 parent 008bd7a commit e322d3c

File tree

7 files changed

+236
-76
lines changed

7 files changed

+236
-76
lines changed

areal/api/controller_api.py

Lines changed: 4 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,6 @@ def forward(
458458
"""
459459
raise NotImplementedError()
460460

461-
462461
class RolloutController(abc.ABC):
463462
"""A centralized controller that manages multiple distributed InferenceEngine workers for rollout generation.
464463
@@ -508,21 +507,6 @@ def destroy(self):
508507
"""Destroy the engine and release GPU memory for the local inference engine."""
509508
raise NotImplementedError()
510509

511-
async def agenerate(self, req: ModelRequest) -> ModelResponse:
512-
"""Asynchronously generate a response for the given request.
513-
514-
Parameters
515-
----------
516-
req : ModelRequest
517-
The model request containing input data and generation parameters
518-
519-
Returns
520-
-------
521-
ModelResponse
522-
The generated response from the model
523-
"""
524-
raise NotImplementedError()
525-
526510
def update_weights(self, meta: WeightUpdateMeta) -> Future:
527511
"""Update weights in the inference engine in a non-blocking manner.
528512
@@ -571,7 +555,7 @@ def get_version(self) -> int:
571555

572556
def submit(
573557
self,
574-
data: Dict[str, Any],
558+
data: DistributedBatch,
575559
workflow: Optional["RolloutWorkflow"] = None,
576560
workflow_builder: Optional[Callable] = None,
577561
should_accept: Callable | None = None,
@@ -623,7 +607,7 @@ def wait(self, count: int, timeout: float | None = None) -> DistributedBatch:
623607

624608
def rollout_batch(
625609
self,
626-
data: List[Dict[str, Any]],
610+
data: DistributedBatch,
627611
workflow: Optional["RolloutWorkflow"] = None,
628612
workflow_builder: Optional[Callable] = None,
629613
should_accept: Callable | None = None,
@@ -652,7 +636,7 @@ def rollout_batch(
652636

653637
def prepare_batch(
654638
self,
655-
dataloader: StatefulDataLoader,
639+
dataloader: DistributedBatch,
656640
workflow: Optional["RolloutWorkflow"] = None,
657641
workflow_builder: Optional[Callable] = None,
658642
should_accept: Callable | None = None,
@@ -688,31 +672,4 @@ def pause(self):
688672

689673
def resume(self):
690674
"""Resume request submission for async rollout."""
691-
raise NotImplementedError()
692-
693-
def register_callback_to_all_worker(
694-
self, method: str, callback: Callable, **kwargs
695-
):
696-
"""Register a callback function for the specified method across all workers.
697-
698-
Partial rollout API. After successful registration, the controller will poll
699-
and call the specified method in a background thread. When the return value
700-
is obtained, it will be used as a parameter to call the `callback` function.
701-
702-
Parameters
703-
----------
704-
method : str
705-
The name of the method to register the callback for
706-
callback : Callable
707-
The callback function to be called with the method's return value
708-
**kwargs
709-
Additional keyword arguments for the callback registration
710-
"""
711-
raise NotImplementedError()
712-
713-
def abort_all_requests(self) -> None:
714-
"""Abort all ongoing requests in the inference engine.
715-
716-
Partial rollout API for canceling all queued and in-progress requests.
717-
"""
718-
raise NotImplementedError()
675+
raise NotImplementedError()

areal/api/engine_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,3 +555,15 @@ def pause(self):
555555
def resume(self):
556556
"""Resume request submission for async rollout."""
557557
raise NotImplementedError()
558+
559+
def get_scheduling_config(self) -> List[Scheduling]:
560+
"""Get the scheduling configuration for the engine.
561+
562+
This includes configuration such as container image, CPU/GPU/memory size.
563+
564+
Returns
565+
-------
566+
Scheduling
567+
The scheduling configuration for the engine
568+
"""
569+
raise NotImplementedError()

areal/api/scheduler_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Worker:
1616
@dataclass
1717
class ScheduleStrategy:
1818
type: Literal["colocation", "separation", ""] = ""
19-
uid: str = ""
19+
target: str = ""
2020

2121

2222
@dataclass

areal/api/workflow_api.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def rollout_batch(
524524

525525
def prepare_batch(
526526
self,
527-
dataloader: StatefulDataLoader,
527+
dataloader: StatefulDataLoader | List[Dict[str, Any]],
528528
workflow: "RolloutWorkflow" | None = None,
529529
workflow_builder: Callable | None = None,
530530
should_accept: Callable | None = None,
@@ -533,28 +533,62 @@ def prepare_batch(
533533
534534
See :meth:`~areal.api.engine_api.InferenceEngine.prepare_batch` for detailed documentation.
535535
"""
536-
if not hasattr(self, "data_generator"):
537-
self.data_generator = cycle_dataloader(dataloader)
538-
assert dataloader.batch_size is not None
539-
while True:
540-
# Submit at least two batches to allow maximum overlap
541-
if (
542-
self.get_capacity() + dataloader.batch_size > 0
543-
and self.input_queue.qsize() + dataloader.batch_size
544-
< self.input_queue.maxsize
545-
):
546-
data = next(self.data_generator)
547-
for item in data:
536+
if isinstance(dataloader, StatefulDataLoader):
537+
# 处理StatefulDataLoader类型 - 保持原有逻辑不变
538+
if not hasattr(self, "data_generator"):
539+
self.data_generator = cycle_dataloader(dataloader)
540+
assert dataloader.batch_size is not None
541+
batch_size = dataloader.batch_size
542+
543+
while True:
544+
# Submit at least two batches to allow maximum overlap
545+
if (
546+
self.get_capacity() + batch_size > 0
547+
and self.input_queue.qsize() + batch_size
548+
< self.input_queue.maxsize
549+
):
550+
data = next(self.data_generator)
551+
for item in data:
552+
self.submit(
553+
item,
554+
workflow=workflow,
555+
workflow_builder=workflow_builder,
556+
should_accept=should_accept,
557+
)
558+
try:
559+
return self.wait(batch_size, timeout=1)
560+
except TimeoutError:
561+
pass
562+
else:
563+
self.data_list_index = 0
564+
565+
# 对于List类型,使用固定的batch_size=1
566+
batch_size = 1
567+
568+
while True:
569+
# Submit at least two batches to allow maximum overlap
570+
if (
571+
self.get_capacity() + batch_size > 0
572+
and self.input_queue.qsize() + batch_size
573+
< self.input_queue.maxsize
574+
):
575+
# 从List中获取数据,支持循环访问
576+
if self.data_list_index >= len(dataloader):
577+
self.data_list_index = 0 # 循环访问
578+
579+
item = dataloader[self.data_list_index]
580+
self.data_list_index += 1
581+
548582
self.submit(
549583
item,
550584
workflow=workflow,
551585
workflow_builder=workflow_builder,
552586
should_accept=should_accept,
553587
)
554-
try:
555-
return self.wait(dataloader.batch_size, timeout=1)
556-
except TimeoutError:
557-
pass
588+
try:
589+
return self.wait(batch_size, timeout=1)
590+
except TimeoutError:
591+
pass
558592

559593
def pause(self):
560594
"""Pause request submission for async rollout.
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from concurrent.futures import ThreadPoolExecutor
2+
from functools import partial
3+
from typing import Any, Callable, Dict, List
4+
5+
from tensordict import TensorDict, stack
6+
7+
from areal.api.cli_args import InferenceEngineConfig
8+
from areal.api.controller_api import RolloutController, DistributedBatch
9+
from areal.api.engine_api import InferenceEngine
10+
from areal.api.io_struct import AllocationMode, WeightUpdateMeta
11+
from areal.api.workflow_api import RolloutWorkflow
12+
13+
from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker
14+
from areal.controller.utils import create_engine_with_retry, rpc_call
15+
from areal.utils.data import concat_padded_tensors
16+
from areal.utils import logging
17+
from areal.utils.http import wait_future_ordered
18+
19+
logger = logging.getLogger("DistributedRolloutController")
20+
21+
22+
class DistributedRolloutController(RolloutController):
23+
def __init__(
24+
self,
25+
inf_engine: InferenceEngine,
26+
config: InferenceEngineConfig,
27+
scheduler: Scheduler,
28+
):
29+
super().__init__(inf_engine, config, scheduler)
30+
self.role: str = "rollout"
31+
self.alloc_mode: AllocationMode
32+
self.enable_colocate_mode: bool
33+
self.dp_world_size: int
34+
self.dp_head_workers: List[Worker]
35+
36+
def initialize(
37+
self,
38+
alloc_mode_str: str,
39+
target: str,
40+
):
41+
self.alloc_mode = AllocationMode.from_str(alloc_mode_str)
42+
self.dp_world_size = self.alloc_mode.gen.world_size // self.alloc_mode.gen.dp_size
43+
44+
job = Job(
45+
replicas=self.alloc_mode.gen.world_size,
46+
tasks=self.inf_engine.get_scheduling_config(),
47+
schedule_strategy=ScheduleStrategy(type="colocation", target=target) if target else None,
48+
role=self.role,
49+
)
50+
logger.info(f"Start to create job: {job}")
51+
self.scheduler.create_workers(job)
52+
53+
workers = self.scheduler.get_workers(self.role, timeout=1800)
54+
self.dp_head_workers = [worker for idx, worker in enumerate(workers) if idx % self.dp_world_size == 0]
55+
assert len(self.dp_head_workers) == self.alloc_mode.gen.dp_size
56+
57+
engine_addrs = [f"{w.ip}:{w.serve_port}" for w in self.dp_head_workers]
58+
with ThreadPoolExecutor(max_workers=len(self.dp_head_workers)) as executor:
59+
futures = [
60+
executor.submit(
61+
partial(
62+
create_engine_with_retry,
63+
self.scheduler.create_engine,
64+
worker.id,
65+
self.inf_engine,
66+
None,
67+
engine_addrs,
68+
self.dp_world_size,
69+
)
70+
)
71+
for worker in self.dp_head_workers
72+
]
73+
74+
wait_future_ordered(futures, exit_on_exception=True)
75+
76+
def destroy(self):
77+
self.scheduler.delete_workers()
78+
79+
def __del__(self):
80+
self.destroy()
81+
82+
def update_weights(self, meta: WeightUpdateMeta) -> None:
83+
"""Update weights in the inference engine."""
84+
self.custom_function_call("update_weights", None, meta)
85+
return None
86+
87+
def prepare_batch(self, data: DistributedBatch, workflow: RolloutWorkflow) -> None:
88+
"""Asynchronously submit a request to the inference engine. Exits immediately."""
89+
batches = data.chunk(self.alloc_mode.gen.dp_size)
90+
self.custom_function_call("prepare_batch", batches, workflow)
91+
return None
92+
93+
def rollout_batch(
94+
self,
95+
data: DistributedBatch,
96+
workflow: RolloutWorkflow
97+
) -> DistributedBatch:
98+
"""Submit a batch of requests to the inference engine and wait for the results."""
99+
batches = data.chunk(self.alloc_mode.gen.dp_size)
100+
results = self.custom_function_call("rollout_distributed_batch", batches, workflow)
101+
assert len(results) > 0
102+
size = int(results[0]["input_ids"].shape[0])
103+
bs = size * len(results)
104+
padded = concat_padded_tensors(results)
105+
if isinstance(padded, dict):
106+
padded = TensorDict(padded, batch_size=[bs])
107+
return DistributedBatch.concat(padded.to_dict())
108+
109+
def set_version(self, version: int) -> None:
110+
self.custom_function_call("set_version", None, version)
111+
return None
112+
113+
def get_version(self) -> int:
114+
results = self.custom_function_call("get_version", None)
115+
return results[0]
116+
117+
def pause(self):
118+
self.custom_function_call("pause", None)
119+
120+
def resume(self):
121+
self.custom_function_call("resume", None)
122+
123+
def submit(self, data: DistributedBatch):
124+
batches = data.chunk(self.alloc_mode.gen.dp_size)
125+
self.custom_function_call("submit", batches)
126+
127+
def wait(self, counts: List[int], timeout: float | None = None)->DistributedBatch:
128+
assert len(counts) == len(self.dp_head_workers)
129+
results = self.custom_function_call("wait", counts, timeout)
130+
return DistributedBatch.concat(results)
131+
132+
def custom_function_call(self, method: str, batches, *args, **kwargs):
133+
return rpc_call(self.scheduler, self.dp_head_workers, method, batches, args, kwargs)

0 commit comments

Comments
 (0)