Skip to content
Draft
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
edf7b5d
.
garrett4wade Oct 22, 2025
337e71a
.
garrett4wade Oct 22, 2025
5ab09a2
merge main
garrett4wade Oct 23, 2025
427f3b0
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 24, 2025
a9dad5a
minor fix import
garrett4wade Oct 24, 2025
33a626b
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 27, 2025
fa0bfd0
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 27, 2025
f660e5b
merge inference engine tests
garrett4wade Oct 27, 2025
78b489d
update
garrett4wade Oct 27, 2025
b0ecf14
Merge branch 'fw/local-inf-engine' of https://github.com/inclusionAI/…
garrett4wade Oct 27, 2025
722afad
fix
garrett4wade Oct 28, 2025
17945e9
merge main
garrett4wade Oct 28, 2025
7a2f6a9
.
garrett4wade Oct 28, 2025
46ee150
add local scheduler
garrett4wade Oct 28, 2025
b1eefc1
merge main
garrett4wade Oct 28, 2025
e471c1e
Merge branch 'fw/ls' of https://github.com/inclusionAI/AReaL into fw/…
garrett4wade Oct 28, 2025
266d6d6
implement run workflow endpoint and rollout controller
garrett4wade Oct 28, 2025
f67dd60
add tensor serialization
garrett4wade Oct 29, 2025
a58c984
fix test
garrett4wade Oct 29, 2025
d14b53c
add scheduler and rollout controller test
garrett4wade Oct 29, 2025
b3a3e53
fix docstring and type annotations
garrett4wade Oct 29, 2025
f223db1
merge train controller commit
garrett4wade Oct 29, 2025
2969c9f
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Oct 29, 2025
a58d0cc
add train controller
garrett4wade Oct 29, 2025
e049f30
init commit train controller
garrett4wade Oct 29, 2025
b4c4eb6
refactor train controller
garrett4wade Oct 29, 2025
5a702a1
add train controller tests
garrett4wade Oct 29, 2025
54ee6fd
renaming
garrett4wade Oct 29, 2025
b21e452
.
garrett4wade Oct 29, 2025
170cc75
update train script
garrett4wade Oct 29, 2025
157b0b0
implement rollout stats
garrett4wade Oct 29, 2025
7475004
.
garrett4wade Oct 29, 2025
6e54a58
fix
garrett4wade Oct 29, 2025
deee027
add sync rpc server
garrett4wade Oct 29, 2025
ece5152
refactor to http server instead of flask
garrett4wade Oct 29, 2025
69805e8
sft run
garrett4wade Oct 30, 2025
c37732c
fix sft; init grpo
garrett4wade Oct 30, 2025
e50b9b0
add rpc server configuration
garrett4wade Oct 30, 2025
a8e75de
except update weight
garrett4wade Oct 30, 2025
beeedd7
grpo run
garrett4wade Oct 30, 2025
9c96a3e
merge main
garrett4wade Oct 31, 2025
ce23d47
update to flask rpc server
garrett4wade Oct 31, 2025
bb60c35
add grpo example
garrett4wade Oct 31, 2025
ae1d6a2
remove local inference engine
garrett4wade Oct 31, 2025
a25d378
minor revert
garrett4wade Oct 31, 2025
99fe517
revert realhf
garrett4wade Oct 31, 2025
2830fac
Merge branch 'fw/local-inf-engine' of https://github.com/inclusionAI/…
garrett4wade Nov 2, 2025
7e133b8
minor config fix
garrett4wade Nov 2, 2025
73912a8
merge tests
garrett4wade Oct 31, 2025
a822cb2
fix docstring
garrett4wade Oct 31, 2025
6e62884
add test
garrett4wade Nov 1, 2025
98d2c8d
fix format
garrett4wade Oct 31, 2025
12cc12e
shorter ctx len for test
garrett4wade Nov 2, 2025
3ba98e6
add adv norm in grpo test
garrett4wade Nov 2, 2025
204b1fd
update test to use local path
garrett4wade Nov 3, 2025
95a08ac
resource cleanup in tests
garrett4wade Nov 3, 2025
d0dfad7
fix vllm pp
garrett4wade Nov 3, 2025
6749916
fix
garrett4wade Nov 3, 2025
52921f2
.
garrett4wade Nov 4, 2025
9258e2e
.
garrett4wade Nov 4, 2025
b640c11
Merge branch 'fw/msvt' of https://github.com/inclusionAI/AReaL into f…
garrett4wade Nov 4, 2025
4443c9b
.
garrett4wade Nov 4, 2025
c664acc
merge main
garrett4wade Nov 4, 2025
ac6a11a
add assertion
garrett4wade Nov 4, 2025
e85a39f
merge main
garrett4wade Nov 4, 2025
d212a38
merge
garrett4wade Nov 4, 2025
54fce97
revert and fix
garrett4wade Nov 4, 2025
eeffd1d
merge main
garrett4wade Nov 10, 2025
749e047
minor revert
garrett4wade Nov 10, 2025
41c2407
merge main
garrett4wade Nov 14, 2025
3b25e2f
Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/l…
garrett4wade Nov 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 102 additions & 35 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any

import uvloop
import yaml
Expand Down Expand Up @@ -311,6 +312,52 @@ class MegatronEngineConfig:
recompute_modules: list[str] | None = None


@dataclass
class SchedulingStrategy:
    """Placement policy for a role's workers relative to other roles.

    ``colocation`` places this role together with the workers of ``target``
    (per the field help below); ``separation`` presumably gives the role its
    own dedicated resources — confirm against the scheduler implementation.
    """

    # Placement mode; restricted to the listed choices by CLI parsing metadata.
    type: str = field(
        default="separation", metadata={"choices": ["separation", "colocation"]}
    )
    # Only meaningful when type == "colocation": the role to share resources with.
    target: str | None = field(
        default=None, metadata={"help": "The target role to be colocated with"}
    )


@dataclass
class SchedulingSpec:
    """Per-task resource and launch specification consumed by the scheduler.

    Combines generic container requirements (CPU/GPU/memory/ports/image),
    the command to launch, and a set of pass-through Slurm ``sbatch``
    options for Slurm-backed deployments.
    """

    cpu: int = field(default=0, metadata={"help": "Number of CPU cores required"})
    gpu: int = field(default=0, metadata={"help": "Number of GPU units required"})
    mem: int = field(default=0, metadata={"help": "Amount of memory (GB) required"})
    port_count: int = field(default=2, metadata={"help": "Number of ports to expose"})
    image: str = field(
        default="", metadata={"help": "Docker/Singularity container image to use"}
    )
    # Coarse task category; restricted to the listed choices by CLI metadata.
    type: str = field(
        default="worker",
        metadata={
            "help": "Task type (e.g., worker, engine)",
            "choices": ["worker", "engine"],
        },
    )
    env_vars: dict[str, str] = field(
        default_factory=dict,
        metadata={"help": "Environment variables for the container"},
    )
    # Launch command; None means the caller substitutes AReaL's RPC server
    # (see the default_factory used by the engine configs in this file).
    cmd: str | None = field(
        default=None,
        metadata={
            "help": "Command to execute inside the container. Defaults to AReaL's RPC server."
        },
    )
    # Slurm configurations from "https://slurm.schedmd.com/sbatch.html";
    # each maps directly to the sbatch option of the same/similar name.
    nodelist: str | None = None
    exclude: str | None = None
    partition: str | None = None
    time_limit: str | None = None  # see "--time" option for format
    begin: str | None = None  # see "--begin" option for format
    deadline: str | None = None  # see "--deadline" option for format


@dataclass
class TrainEngineConfig:
"""Core configuration for model training, including optimization and backend settings."""
Expand Down Expand Up @@ -384,6 +431,13 @@ class TrainEngineConfig:
default="lora",
metadata={"help": "peft method type. Only LoRA is supported for now."},
)
scheduling_spec: SchedulingSpec = field(
default_factory=lambda: SchedulingSpec(
cmd="python -m areal.scheduler.rpc.rpc_server"
),
metadata={"help": "train engine schedule specs"},
)
scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy)


@dataclass
Expand Down Expand Up @@ -538,6 +592,24 @@ class PPOCriticConfig(TrainEngineConfig):
)


def get_py_cmd(module: str, args: dict[str, Any]) -> list[str]:
    """Render an argv list ``["python3", "-m", module, ...]`` from option values.

    Each key becomes a ``--key`` flag with underscores replaced by dashes.
    Values of ``None``, ``False``, ``""``, and empty lists are omitted
    entirely; ``True`` emits a bare flag; a list emits the flag followed by
    each element stringified; anything else emits the flag followed by
    ``str(value)``.

    Args:
        module: Dotted module path to run via ``python -m``.
        args: Mapping of option names to values.

    Returns:
        The command as a list of argv tokens.
    """
    argv = ["python3", "-m", module]
    for name, value in args.items():
        # Drop options that carry no information.
        if value is None or value is False or value == "":
            continue
        if isinstance(value, list) and not value:
            continue
        flag = "--" + name.replace("_", "-")
        if value is True:
            # Boolean switch: presence alone enables it.
            argv.append(flag)
        elif isinstance(value, list):
            argv.append(flag)
            argv += [str(item) for item in value]
        else:
            argv += [flag, str(value)]
    return argv


@dataclass
class vLLMConfig:
"""Configuration for vLLM runtime. Refer to:
Expand Down Expand Up @@ -598,6 +670,10 @@ def build_args(
)
return args

@staticmethod
def build_cmd_from_args(args: dict[str, Any]):
return get_py_cmd("areal.thirdparty.vllm.areal_vllm_server", args)

@staticmethod
def build_cmd(
vllm_config: "vLLMConfig",
Expand All @@ -615,18 +691,7 @@ def build_cmd(
port=port,
dist_init_addr=dist_init_addr,
)
# convert to flags
flags = []
for k, v in args.items():
if v is None or v is False or v == "":
continue
if v is True:
flags.append(f"--{k.replace('_', '-')}")
elif isinstance(v, list):
flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
else:
flags.append(f"--{k.replace('_', '-')} {v}")
return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}"
return vLLMConfig.build_cmd_from_args(args)


@dataclass
Expand Down Expand Up @@ -724,28 +789,19 @@ def build_cmd(
node_rank=node_rank,
)

# convert to flags
flags = []
for k, v in args.items():
if is_version_less("sglang", "0.4.10.post2") and "max_loaded_loras" in k:
continue
if v is None or v is False or v == "":
continue
if v is True:
flags.append(f"--{k.replace('_', '-')}")
elif isinstance(v, list):
flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
else:
flags.append(f"--{k.replace('_', '-')} {v}")
return f"python3 -m sglang.launch_server {' '.join(flags)}"
return SGLangConfig.build_cmd_from_args(args)

@staticmethod
def build_cmd_from_args(args: dict[str, Any]):
return get_py_cmd("sglang.launch_server", args)

@staticmethod
def build_args(
sglang_config: "SGLangConfig",
tp_size,
base_gpu_id,
host,
port,
tp_size: int,
base_gpu_id: int,
host: str | None = None,
port: str | None = None,
dist_init_addr: str | None = None,
n_nodes: int = 1,
node_rank: int = 0,
Expand All @@ -761,19 +817,17 @@ def build_args(
enable_multithread_load=sglang_config.enable_multithread_load,
enable_fast_load=sglang_config.enable_fast_load,
)
args.pop("enable_multithread_load", None)
args.pop("enable_fast_load", None)
args["model_loader_extra_config"] = json.dumps(
model_loader_extra_config, separators=(",", ":")
)
args.pop("enable_multithread_load", None)
args.pop("enable_fast_load", None)
# Map "all-linear" to "all"
if "lora_target_modules" in args and args["lora_target_modules"]:
args["lora_target_modules"] = [
x.replace("-linear", "") for x in args["lora_target_modules"]
]
args = dict(
host=host,
port=port,
# Model and tokenizer
tokenizer_path=sglang_config.model_path,
tokenizer_mode="auto",
Expand All @@ -791,8 +845,14 @@ def build_args(
dist_init_addr=dist_init_addr,
**args,
)
if host is not None:
args["host"] = host
if port is not None:
args["port"] = port
if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"):
raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.")
if is_version_less("sglang", "0.4.10.post2"):
args.pop("max_loaded_loras", None)
return args


Expand All @@ -811,7 +871,7 @@ class InferenceEngineConfig:
)
queue_size: None | int = field(
default=None,
metadata={"help": "Input/Output queue size for async rollout."},
metadata={"help": "(Deprecated) Input/Output queue size for async rollout."},
)
consumer_batch_size: int = field(
default=1,
Expand Down Expand Up @@ -859,6 +919,13 @@ class InferenceEngineConfig:
"help": "The grace period after calling /pause_generation. Wait until all requests have been dropped."
},
)
scheduling_spec: SchedulingSpec = field(
default_factory=lambda: SchedulingSpec(
cmd="python -m areal.scheduler.rpc.rpc_server"
),
metadata={"help": "inference engine schedule specs"},
)
scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy)


@dataclass
Expand Down
Loading
Loading