Skip to content

Commit 3159d39

Browse files
author
chucai.dzq
committed
add rollout controller
impl rollout controller
1 parent 076c3ba commit 3159d39

File tree

10 files changed

+574
-100
lines changed

10 files changed

+574
-100
lines changed

areal/api/controller_api.py

Lines changed: 8 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def set_version(self, version: int):
315315
"""
316316
raise NotImplementedError()
317317

318-
def get_version(self) -> int:
318+
def get_version(self) -> List[int]:
319319
"""Get the current weight version in the training engine.
320320
321321
Returns
@@ -359,7 +359,7 @@ def train_batch(
359359
input_: DistributedBatch,
360360
loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
361361
loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
362-
) -> Dict[str, float]:
362+
) -> List[Dict[str, float]]:
363363
"""Update the model with a batch of data and a loss function.
364364
365365
Note
@@ -382,7 +382,7 @@ def train_batch(
382382
383383
Returns
384384
-------
385-
Dict[str, float]
385+
List[Dict[str, float]]
386386
Scalar statistics after training, e.g., the current learning rate,
387387
gradient norm, etc.
388388
"""
@@ -394,7 +394,7 @@ def eval_batch(
394394
input_: DistributedBatch,
395395
loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
396396
loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
397-
) -> torch.Tensor | None:
397+
) -> List[torch.Tensor]:
398398
"""Evaluate the model using the forward pass and loss function.
399399
400400
Note
@@ -458,7 +458,6 @@ def forward(
458458
"""
459459
raise NotImplementedError()
460460

461-
462461
class RolloutController(abc.ABC):
463462
"""A centralized controller that manages multiple distributed InferenceEngine workers for rollout generation.
464463
@@ -508,21 +507,6 @@ def destroy(self):
508507
"""Destroy the engine and release GPU memory for the local inference engine."""
509508
raise NotImplementedError()
510509

511-
async def agenerate(self, req: ModelRequest) -> ModelResponse:
512-
"""Asynchronously generate a response for the given request.
513-
514-
Parameters
515-
----------
516-
req : ModelRequest
517-
The model request containing input data and generation parameters
518-
519-
Returns
520-
-------
521-
ModelResponse
522-
The generated response from the model
523-
"""
524-
raise NotImplementedError()
525-
526510
def update_weights(self, meta: WeightUpdateMeta) -> Future:
527511
"""Update weights in the inference engine in a non-blocking manner.
528512
@@ -571,7 +555,7 @@ def get_version(self) -> int:
571555

572556
def submit(
573557
self,
574-
data: Dict[str, Any],
558+
data: DistributedBatch,
575559
workflow: Optional["RolloutWorkflow"] = None,
576560
workflow_builder: Optional[Callable] = None,
577561
should_accept: Callable | None = None,
@@ -623,7 +607,7 @@ def wait(self, count: int, timeout: float | None = None) -> DistributedBatch:
623607

624608
def rollout_batch(
625609
self,
626-
data: List[Dict[str, Any]],
610+
data: DistributedBatch,
627611
workflow: Optional["RolloutWorkflow"] = None,
628612
workflow_builder: Optional[Callable] = None,
629613
should_accept: Callable | None = None,
@@ -652,7 +636,7 @@ def rollout_batch(
652636

653637
def prepare_batch(
654638
self,
655-
dataloader: StatefulDataLoader,
639+
dataloader: DistributedBatch,
656640
workflow: Optional["RolloutWorkflow"] = None,
657641
workflow_builder: Optional[Callable] = None,
658642
should_accept: Callable | None = None,
@@ -688,31 +672,4 @@ def pause(self):
688672

689673
def resume(self):
690674
"""Resume request submission for async rollout."""
691-
raise NotImplementedError()
692-
693-
def register_callback_to_all_worker(
694-
self, method: str, callback: Callable, **kwargs
695-
):
696-
"""Register a callback function for the specified method across all workers.
697-
698-
Partial rollout API. After successful registration, the controller will poll
699-
and call the specified method in a background thread. When the return value
700-
is obtained, it will be used as a parameter to call the `callback` function.
701-
702-
Parameters
703-
----------
704-
method : str
705-
The name of the method to register the callback for
706-
callback : Callable
707-
The callback function to be called with the method's return value
708-
**kwargs
709-
Additional keyword arguments for the callback registration
710-
"""
711-
raise NotImplementedError()
712-
713-
def abort_all_requests(self) -> None:
714-
"""Abort all ongoing requests in the inference engine.
715-
716-
Partial rollout API for canceling all queued and in-progress requests.
717-
"""
718-
raise NotImplementedError()
675+
raise NotImplementedError()

areal/api/engine_api.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class Scheduling:
2525
cpu: int
2626
gpu: int
2727
mem: int
28+
port_count: int
29+
cmd: str | None = None
2830
nodelist: str | None = None
2931
exclude: str | None = None
3032
partition: str | None = None
@@ -138,7 +140,7 @@ def parallelism_group(self) -> dist.ProcessGroup:
138140
"""
139141
raise NotImplementedError()
140142

141-
def get_scheduling_config(self) -> Scheduling:
143+
def get_scheduling_config(self) -> List[Scheduling]:
142144
"""Get the scheduling configuration for the engine.
143145
144146
This includes configuration such as container image, CPU/GPU/memory size.
@@ -553,3 +555,15 @@ def pause(self):
553555
def resume(self):
554556
"""Resume request submission for async rollout."""
555557
raise NotImplementedError()
558+
559+
def get_scheduling_config(self) -> List[Scheduling]:
560+
"""Get the scheduling configuration for the engine.
561+
562+
This includes configuration such as container image, CPU/GPU/memory size.
563+
564+
Returns
565+
-------
566+
List[Scheduling]
567+
The scheduling configuration for the engine
568+
"""
569+
raise NotImplementedError()

areal/api/scheduler_api.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,40 @@
11
import abc
22
from dataclasses import dataclass, field
3-
from typing import Dict, List
3+
from typing import List, Literal
4+
5+
from areal.api.engine_api import Scheduling
46

57

68
@dataclass
79
class Worker:
810
id: str
911
ip: str
10-
ports: List[str] = field(default_factory=list)
11-
12-
13-
@dataclass
14-
class ContainerSpec:
15-
cpu: int = 0
16-
gpu: int = 0
17-
mem: int = 0
18-
container_image: str = ""
19-
cmd: str = ""
20-
env_vars: Dict[str, str] = field(default_factory=dict)
21-
port_count: int = 2
12+
serve_port: str
13+
extra_ports: List[str] = field(default_factory=list)
2214

2315

2416
@dataclass
2517
class ScheduleStrategy:
26-
type: str = ""
27-
uid: str = ""
18+
type: Literal["colocation", "separation", ""] = ""
19+
target: str = ""
2820

2921

3022
@dataclass
31-
class SchedulingConfig:
23+
class Job:
3224
replicas: int = 0
33-
specs: List[ContainerSpec] = field(default_factory=list)
25+
tasks: List[Scheduling] = field(default_factory=list)
3426
schedule_strategy: ScheduleStrategy | None = None
3527
role: str = ""
3628

3729

3830
class Scheduler(abc.ABC):
3931
def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> None:
4032
"""
41-
Start workers, return job id
33+
Start workers
4234
"""
4335
raise NotImplementedError()
4436

45-
def get_workers(self, worker_key, timeout=None) -> List[Worker]:
37+
def get_workers(self, role: str, timeout=None) -> List[Worker]:
4638
"""
4739
Wait and return worker list, including scheduling results such as ip and engine ports
4840
(worker id, ip, ports)

areal/api/workflow_api.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def rollout_batch(
521521

522522
def prepare_batch(
523523
self,
524-
dataloader: StatefulDataLoader,
524+
dataloader: StatefulDataLoader | List[Dict[str, Any]],
525525
workflow: "RolloutWorkflow" | None = None,
526526
workflow_builder: Callable | None = None,
527527
should_accept: Callable | None = None,
@@ -530,28 +530,62 @@ def prepare_batch(
530530
531531
See :meth:`~areal.api.engine_api.InferenceEngine.prepare_batch` for detailed documentation.
532532
"""
533-
if not hasattr(self, "data_generator"):
534-
self.data_generator = cycle_dataloader(dataloader)
535-
assert dataloader.batch_size is not None
536-
while True:
537-
# Submit at least two batches to allow maximum overlap
538-
if (
539-
self.get_capacity() + dataloader.batch_size > 0
540-
and self.input_queue.qsize() + dataloader.batch_size
541-
< self.input_queue.maxsize
542-
):
543-
data = next(self.data_generator)
544-
for item in data:
533+
if isinstance(dataloader, StatefulDataLoader):
534+
# Handle StatefulDataLoader input - keep the original logic unchanged
535+
if not hasattr(self, "data_generator"):
536+
self.data_generator = cycle_dataloader(dataloader)
537+
assert dataloader.batch_size is not None
538+
batch_size = dataloader.batch_size
539+
540+
while True:
541+
# Submit at least two batches to allow maximum overlap
542+
if (
543+
self.get_capacity() + batch_size > 0
544+
and self.input_queue.qsize() + batch_size
545+
< self.input_queue.maxsize
546+
):
547+
data = next(self.data_generator)
548+
for item in data:
549+
self.submit(
550+
item,
551+
workflow=workflow,
552+
workflow_builder=workflow_builder,
553+
should_accept=should_accept,
554+
)
555+
try:
556+
return self.wait(batch_size, timeout=1)
557+
except TimeoutError:
558+
pass
559+
else:
560+
self.data_list_index = 0
561+
562+
# For list input, use a fixed batch_size of 1
563+
batch_size = 1
564+
565+
while True:
566+
# Submit at least two batches to allow maximum overlap
567+
if (
568+
self.get_capacity() + batch_size > 0
569+
and self.input_queue.qsize() + batch_size
570+
< self.input_queue.maxsize
571+
):
572+
# Fetch items from the list, wrapping around for cyclic access
573+
if self.data_list_index >= len(dataloader):
574+
self.data_list_index = 0 # 循环访问
575+
576+
item = dataloader[self.data_list_index]
577+
self.data_list_index += 1
578+
545579
self.submit(
546580
item,
547581
workflow=workflow,
548582
workflow_builder=workflow_builder,
549583
should_accept=should_accept,
550584
)
551-
try:
552-
return self.wait(dataloader.batch_size, timeout=1)
553-
except TimeoutError:
554-
pass
585+
try:
586+
return self.wait(batch_size, timeout=1)
587+
except TimeoutError:
588+
pass
555589

556590
def pause(self):
557591
"""Pause request submission for async rollout.

0 commit comments

Comments
 (0)