71 commits
edf7b5d
.
garrett4wade Oct 22, 2025
337e71a
.
garrett4wade Oct 22, 2025
5ab09a2
merge main
garrett4wade Oct 23, 2025
427f3b0
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 24, 2025
a9dad5a
minor fix import
garrett4wade Oct 24, 2025
33a626b
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 27, 2025
fa0bfd0
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 27, 2025
f660e5b
merge inferece engine tests
garrett4wade Oct 27, 2025
78b489d
update
garrett4wade Oct 27, 2025
b0ecf14
Merge branch 'fw/local-inf-engine' of https://github.com/inclusionAI/…
garrett4wade Oct 27, 2025
722afad
fix
garrett4wade Oct 28, 2025
17945e9
merge main
garrett4wade Oct 28, 2025
7a2f6a9
.
garrett4wade Oct 28, 2025
46ee150
add local scheduler
garrett4wade Oct 28, 2025
b1eefc1
merge main
garrett4wade Oct 28, 2025
e471c1e
Merge branch 'fw/ls' of https://github.com/inclusionAI/AReaL into fw/…
garrett4wade Oct 28, 2025
266d6d6
implement run workflow endpoint and rolllout controller
garrett4wade Oct 28, 2025
f67dd60
add tensor serialization
garrett4wade Oct 29, 2025
a58c984
fix test
garrett4wade Oct 29, 2025
d14b53c
add scheduler and rollout controller test
garrett4wade Oct 29, 2025
b3a3e53
fix docstring and type annotations
garrett4wade Oct 29, 2025
f223db1
merge train controller commit
garrett4wade Oct 29, 2025
2969c9f
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 29, 2025
a58d0cc
add train controller
garrett4wade Oct 29, 2025
e049f30
init commit train controller
garrett4wade Oct 29, 2025
b4c4eb6
refactor train controller
garrett4wade Oct 29, 2025
5a702a1
add train controller tests
garrett4wade Oct 29, 2025
54ee6fd
renaming
garrett4wade Oct 29, 2025
b21e452
.
garrett4wade Oct 29, 2025
170cc75
update train script
garrett4wade Oct 29, 2025
157b0b0
implement rollout stats
garrett4wade Oct 29, 2025
7475004
.
garrett4wade Oct 29, 2025
6e54a58
fix
garrett4wade Oct 29, 2025
deee027
add sync rpc server
garrett4wade Oct 29, 2025
ece5152
refactor to http server instead of flask
garrett4wade Oct 29, 2025
69805e8
sft run
garrett4wade Oct 30, 2025
c37732c
fix sft; init grpo
garrett4wade Oct 30, 2025
e50b9b0
add rpc server configuration
garrett4wade Oct 30, 2025
a8e75de
except update weight
garrett4wade Oct 30, 2025
beeedd7
grpo run
garrett4wade Oct 30, 2025
9c96a3e
merge main
garrett4wade Oct 31, 2025
ce23d47
update to flask rpc server
garrett4wade Oct 31, 2025
bb60c35
add grpo example
garrett4wade Oct 31, 2025
ae1d6a2
remove local inference engine
garrett4wade Oct 31, 2025
a25d378
minor revert
garrett4wade Oct 31, 2025
99fe517
revert realhf
garrett4wade Oct 31, 2025
2830fac
Merge branch 'fw/local-inf-engine' of https://github.com/inclusionAI/…
garrett4wade Nov 2, 2025
7e133b8
minor config fix
garrett4wade Nov 2, 2025
73912a8
merge tests
garrett4wade Oct 31, 2025
a822cb2
fix docstring
garrett4wade Oct 31, 2025
6e62884
add test
garrett4wade Nov 1, 2025
98d2c8d
fix format
garrett4wade Oct 31, 2025
12cc12e
shorter ctx len for test
garrett4wade Nov 2, 2025
3ba98e6
add adv norm in grpo test
garrett4wade Nov 2, 2025
204b1fd
update test to use local path
garrett4wade Nov 3, 2025
95a08ac
resource cleanup in tests
garrett4wade Nov 3, 2025
d0dfad7
fix vllm pp
garrett4wade Nov 3, 2025
6749916
fix
garrett4wade Nov 3, 2025
52921f2
.
garrett4wade Nov 4, 2025
9258e2e
.
garrett4wade Nov 4, 2025
b640c11
Merge branch 'fw/msvt' of https://github.com/inclusionAI/AReaL into f…
garrett4wade Nov 4, 2025
4443c9b
.
garrett4wade Nov 4, 2025
c664acc
merge main
garrett4wade Nov 4, 2025
ac6a11a
add assertion
garrett4wade Nov 4, 2025
e85a39f
merge main
garrett4wade Nov 4, 2025
d212a38
merge
garrett4wade Nov 4, 2025
54fce97
revert and fix
garrett4wade Nov 4, 2025
eeffd1d
merge main
garrett4wade Nov 10, 2025
749e047
minor revert
garrett4wade Nov 10, 2025
41c2407
merge main
garrett4wade Nov 14, 2025
3b25e2f
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Nov 14, 2025
60 changes: 60 additions & 0 deletions areal/api/cli_args.py
@@ -372,6 +372,52 @@ class MegatronEngineConfig:
recompute_modules: list[str] | None = None


@dataclass
class SchedulingStrategy:
type: str = field(
default="separation", metadata={"choices": ["separation", "colocation"]}
)
target: str | None = field(
default=None, metadata={"help": "The target role to be colocated with"}
)


@dataclass
class SchedulingSpec:
cpu: int = field(default=0, metadata={"help": "Number of CPU cores required"})
gpu: int = field(default=0, metadata={"help": "Number of GPU units required"})
mem: int = field(default=0, metadata={"help": "Amount of memory (GB) required"})
port_count: int = field(default=2, metadata={"help": "Number of ports to expose"})
image: str = field(
default="", metadata={"help": "Docker/Singularity container image to use"}
)
type: str = field(
default="worker",
metadata={
"help": "Task type (e.g., worker, engine)",
"choices": ["worker", "engine"],
},
)
env_vars: dict[str, str] = field(
default_factory=dict,
metadata={"help": "Environment variables for the container"},
)
# cmd
cmd: str | None = field(
default=None,
metadata={
"help": "Command to execute inside the container. Defaults to AReaL's RPC server."
},
)
# slurm configurations from "https://slurm.schedmd.com/sbatch.html"
nodelist: str | None = None
exclude: str | None = None
partition: str | None = None
time_limit: str | None = None # see "--time" option for format
begin: str | None = None # see "--begin" option for format
deadline: str | None = None # see "--deadline" option for format


@dataclass
class TrainEngineConfig:
"""Core configuration for model training, including optimization and backend settings."""
@@ -442,6 +488,13 @@ class TrainEngineConfig:
default="lora",
metadata={"help": "peft method type. Only LoRA is supported for now."},
)
scheduling_spec: SchedulingSpec = field(
default_factory=lambda: SchedulingSpec(
cmd="python -m areal.scheduler.rpc.rpc_server"
),
metadata={"help": "train engine schedule specs"},
)
scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy)


@dataclass
@@ -924,6 +977,13 @@ class InferenceEngineConfig:
"help": "The grace period after calling /pause_generation. Wait until all requests have been dropped."
},
)
scheduling_spec: SchedulingSpec = field(
default_factory=lambda: SchedulingSpec(
cmd="python -m areal.scheduler.rpc.rpc_server"
),
metadata={"help": "inference engine schedule specs"},
)
scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy)


@dataclass
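The two dataclasses added above attach per-engine scheduling information to both TrainEngineConfig and InferenceEngineConfig, each defaulting the worker command to AReaL's RPC server. A minimal sketch of constructing them directly is shown below; the field names and the import path come from this diff, while the container image, environment variable, and Slurm values are illustrative assumptions, not defaults from this PR.

from areal.api.cli_args import SchedulingSpec, SchedulingStrategy

# Resource request for one worker container (values are examples only).
train_spec = SchedulingSpec(
    cpu=16,
    gpu=8,
    mem=256,                       # memory in GB, per the field help text
    port_count=2,
    image="areal/runtime:latest",  # hypothetical container image
    type="worker",
    env_vars={"NCCL_DEBUG": "WARN"},
    cmd="python -m areal.scheduler.rpc.rpc_server",  # same default used by the configs above
    partition="train",             # optional Slurm hint; see the sbatch docs referenced above
    time_limit="24:00:00",
)

# Place this role on the same machines as an existing "actor" role
# instead of the default "separation" strategy.
colocate_with_actor = SchedulingStrategy(type="colocation", target="actor")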
171 changes: 138 additions & 33 deletions areal/api/scheduler_api.py
@@ -1,68 +1,173 @@
import abc
from dataclasses import dataclass, field
from typing import Any

from areal.api.cli_args import SchedulingSpec, SchedulingStrategy


@dataclass
class Worker:
"""
Represents a worker process in the distributed system.

Attributes:
id: Unique identifier for the worker (e.g., "rollout/0", "actor/1").
ip: IP address where the worker is running.
ports: List of port numbers (as strings) allocated to this worker for RPC communication.
"""

id: str
# The worker and its engine are deployed on the same machine, so they share one IP.
ip: str
ports: list[str] = field(default_factory=list)
worker_ports: list[str] = field(default_factory=list)
engine_ports: list[str] = field(default_factory=list)


@dataclass
class ContainerSpec:
cpu: int = 0
gpu: int = 0
mem: int = 0
container_image: str = ""
cmd: str = ""
env_vars: dict[str, str] = field(default_factory=dict)
port_count: int = 2
class Job:
replicas: int = 0
tasks: list[SchedulingSpec] = field(default_factory=list)
scheduling_strategy: SchedulingStrategy | None = None
role: str = ""


@dataclass
class ScheduleStrategy:
type: str = ""
uid: str = ""
class Scheduler(abc.ABC):
"""
Abstract base class for schedulers that manage distributed worker processes.

A scheduler is responsible for:
- Creating and managing worker processes/containers.
- Allocating resources (GPUs, ports, memory).
- Creating and managing engine instances on workers.
- Facilitating RPC calls to engine methods.
"""

@dataclass
class SchedulingConfig:
replicas: int = 0
specs: list[ContainerSpec] = field(default_factory=list)
schedule_strategy: ScheduleStrategy | None = None
role: str = ""
@abc.abstractmethod
def create_workers(self, job: Job, *args, **kwargs) -> list[str]:
"""
Create and start worker processes for a specific role.

Args:
job: Job specification describing replicas, resources, role, and scheduling strategy.
*args: Additional positional arguments (implementation-specific).
**kwargs: Additional keyword arguments (implementation-specific).

class Scheduler(abc.ABC):
def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str:
Returns:
List of worker IDs created (e.g., ["rollout/0", "rollout/1"]).

Raises:
WorkerCreationError: If worker creation fails.
ValueError: If the job specification is invalid.
"""
Start workers, return job id
raise NotImplementedError()

@abc.abstractmethod
def get_workers(self, role: str, timeout: int | None = None) -> list[Worker]:
"""
Wait for workers to be ready and return their information.

This method blocks until all workers for the specified role are ready
to accept RPC requests, or until the timeout is reached.

Args:
role: Role name to query (e.g., "rollout", "actor").
timeout: Maximum time to wait in seconds. None means use the default timeout.

Returns:
List of Worker objects containing worker ID, IP address, and allocated ports.

Raises:
WorkerNotFoundError: If no workers exist for the specified role.
WorkerFailedError: If any worker process has failed.
WorkerTimeoutError: If timeout is exceeded while waiting for workers.
"""
raise NotImplementedError()

def get_workers(self, worker_key, timeout=None) -> list[Worker]:
@abc.abstractmethod
def delete_workers(self, role: str | None = None):
"""
Wait and return worker list, including scheduling results such as ip and engine ports
(worker id, ip, ports)
Stop and clean up worker processes.

Args:
role: Specific role to delete. If None, all workers are deleted.

Raises:
WorkerNotFoundError: If the specified role doesn't exist.

Note:
This method should gracefully terminate workers and clean up resources.
It should not raise an exception if workers have already stopped.
"""
raise NotImplementedError()

def delete_workers(self):
"""stop all workers
@abc.abstractmethod
async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> Any:
"""
Create an engine instance on a remote worker.

The engine parameter is a string import path (e.g., "areal.engine.ppo.actor.FSDPPPOActor")
that will be dynamically imported and instantiated on the worker.

Args:
worker_id: ID of the worker to create the engine on (e.g., "rollout/0").
engine: Import path to the engine class (e.g., "areal.engine.ppo.actor.FSDPPPOActor").
*args: Positional arguments passed to engine initialization.
**kwargs: Keyword arguments passed to engine initialization.

Raises exception if there is no such job, but passes if the job
has stopped either successfully or not.
Returns:
Result from engine initialization.

Raises:
WorkerNotFoundError: If the specified worker doesn't exist.
WorkerFailedError: If the worker process has failed.
EngineCreationError: If engine creation or initialization fails.
"""
raise NotImplementedError()

async def create_engine(self, worker_id, engine_obj, *args, **kwargs):
@abc.abstractmethod
def call_engine(self, worker_id: str, method: str, *args, **kwargs) -> Any:
"""
Create engine instance remotely
Call a method on an engine instance running on a worker (data plane operation).

This is the synchronous version. Use `async_call_engine` for async operations.

Args:
worker_id: ID of the worker hosting the engine (e.g., "rollout/0").
method: Name of the method to call on the engine.
*args: Positional arguments to pass to the method.
**kwargs: Keyword arguments to pass to the method.

Returns:
Result from the engine method call.

Raises:
WorkerNotFoundError: If the specified worker doesn't exist.
WorkerFailedError: If the worker process has failed.
EngineCallError: If the method call fails.
"""
raise NotImplementedError()

def call_engine(self, worker_id, method, *args, **kwargs):
@abc.abstractmethod
async def async_call_engine(
self, worker_id: str, method: str, *args, **kwargs
) -> Any:
"""
Data plane call
Async version of call_engine for calling engine methods asynchronously.

This is useful for concurrent operations or when integrating with async frameworks.

Args:
worker_id: ID of the worker hosting the engine (e.g., "rollout/0").
method: Name of the method to call on the engine.
*args: Positional arguments to pass to the method.
**kwargs: Keyword arguments to pass to the method.

Returns:
Result from the engine method call.

Raises:
WorkerNotFoundError: If the specified worker doesn't exist.
WorkerFailedError: If the worker process has failed.
EngineCallError: If the method call fails.
"""
raise NotImplementedError()
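Taken together, the reworked interface implies a driver-side flow of create_workers → get_workers → create_engine → call_engine. The sketch below strings those calls together under that assumption; the concrete `scheduler` object and the final engine method name are placeholders, and only the signatures and the FSDPPPOActor import path are taken from the docstrings above.

from areal.api.cli_args import SchedulingSpec
from areal.api.scheduler_api import Job, Scheduler

async def launch_rollout(scheduler: Scheduler, spec: SchedulingSpec):
    # One role, two replicas; `tasks` carries the per-worker resource spec.
    job = Job(replicas=2, tasks=[spec], role="rollout")
    scheduler.create_workers(job)                      # -> ["rollout/0", "rollout/1"]

    # Block until the workers accept RPCs, then instantiate an engine on each.
    workers = scheduler.get_workers("rollout", timeout=300)
    for worker in workers:
        await scheduler.create_engine(
            worker.id,
            "areal.engine.ppo.actor.FSDPPPOActor",     # import path quoted from the docstring above
        )

    # Data-plane call; the method name here is a placeholder, not a real engine API.
    return scheduler.call_engine(workers[0].id, "some_engine_method")

# e.g. asyncio.run(launch_rollout(my_scheduler, my_spec)) with a concrete Scheduler implementation.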
11 changes: 11 additions & 0 deletions areal/controller/__init__.py
@@ -0,0 +1,11 @@
"""Controller components for managing distributed training and inference."""

from areal.controller.batch import DistributedBatchMemory
from areal.controller.rollout_controller import RolloutController
from areal.controller.train_controller import TrainController

__all__ = [
"DistributedBatchMemory",
"RolloutController",
"TrainController",
]
36 changes: 10 additions & 26 deletions areal/controller/batch.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Any

import torch
from torch import Tensor
@@ -9,6 +9,7 @@
convert_list_to_dict,
validate_dict_dataset,
)
from areal.utils.data import concat_padded_tensors
from areal.utils.datapack import ffd_allocate
from areal.utils.errors import FrameworkError

@@ -17,7 +18,7 @@ class DistributedBatchMemory(DistributedBatch):
dataset = None

@classmethod
def from_dict(cls, dict_dataset: Dict[str, Union[Tensor, Any]]):
def from_dict(cls, dict_dataset: dict[str, Tensor | Any]):
"""Create a DistributedBatchMemory from dictionary format dataset.

Parameters
@@ -36,7 +37,7 @@ def from_dict(cls, dict_dataset: Dict[str, Union[Tensor, Any]]):
return instance

@classmethod
def from_list(cls, list_dataset: List[Dict[str, Union[Tensor, Any]]]):
def from_list(cls, list_dataset: list[dict[str, Tensor | Any]]):
"""Create a DistributedBatchMemory from list format dataset.

Parameters
@@ -103,9 +104,9 @@ def chunk_by_ffd(
List of DistributedBatchMemory objects
"""
total_size = self._get_total_size()
assert (
total_size % group_size == 0
), "tensor length must be devided by group_size"
assert total_size % group_size == 0, (
"tensor length must be devided by group_size"
)

# Handle seqlen calculation for both tensor and scalar types
if "seqlen" in self.dataset.keys():
@@ -209,7 +210,7 @@ def _get_total_size(self) -> int:
# For scalar values, assume it's a single sample
return 1

def get_data(self) -> Dict[str, Union[torch.Tensor, Any]]:
def get_data(self) -> dict[str, torch.Tensor | Any]:
"""Get all data from the DistributedBatchMemory.

Returns
@@ -253,24 +254,7 @@ def concat(data: list["DistributedBatchMemory"]) -> "DistributedBatchMemory":
batch.dataset = {}
return batch

merged_data = {}
for batch in data:
for k, v in batch.dataset.items():
if k in merged_data:
if isinstance(merged_data[k], torch.Tensor) and isinstance(
v, torch.Tensor
):
merged_data[k] = torch.cat([merged_data[k], v], dim=0)
elif isinstance(merged_data[k], list) and isinstance(v, list):
merged_data[k] = merged_data[k] + v
else:
# Handle mixed types or scalar values
if isinstance(merged_data[k], list):
merged_data[k].append(v)
else:
merged_data[k] = [merged_data[k], v]
else:
merged_data[k] = v
merged_data = concat_padded_tensors([k.dataset for k in data])
result = DistributedBatchMemory.__new__(DistributedBatchMemory)
result.dataset = merged_data
return result
@@ -319,7 +303,7 @@ def __setitem__(self, key, value):
self.dataset[key] = value
else:
raise FrameworkError(
"FrameworkError", "DistributedBatchMemoryError", f"key must be str"
"FrameworkError", "DistributedBatchMemoryError", "key must be str"
)

def __delitem__(self, key):
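Since concat now delegates merging to concat_padded_tensors rather than the hand-rolled per-key loop, a minimal usage sketch might look as follows. It assumes from_dict and concat_padded_tensors accept a plain dictionary of batch tensors and stack along the batch dimension; the exact padding and validation behavior depends on utilities not shown in this diff.

import torch

from areal.controller import DistributedBatchMemory

# Two small batches with identical keys and sequence lengths (kept equal here
# to sidestep any padding details handled inside concat_padded_tensors).
a = DistributedBatchMemory.from_dict(
    {"input_ids": torch.zeros(2, 8, dtype=torch.long), "seqlen": torch.tensor([8, 8])}
)
b = DistributedBatchMemory.from_dict(
    {"input_ids": torch.ones(2, 8, dtype=torch.long), "seqlen": torch.tensor([8, 8])}
)

merged = DistributedBatchMemory.concat([a, b])
print(merged.get_data()["input_ids"].shape)  # expected: torch.Size([4, 8])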