1 change: 1 addition & 0 deletions areal/api/alloc_mode.py
@@ -839,6 +839,7 @@ def parse(self, expression: str):
AllocationValidationError: When validation rules are violated
ValueError: When parsing fails
"""
+
try:
tree = self.parser.parse(expression)
transformer = _ParallelStrategyTransformer()
72 changes: 34 additions & 38 deletions areal/api/cli_args.py
@@ -3,21 +3,19 @@
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
- from typing import Dict, List

import uvloop
import yaml

- from areal.utils.pkg_version import is_version_less
-
- uvloop.install()
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from hydra.core.global_hydra import GlobalHydra
from omegaconf import MISSING, DictConfig, OmegaConf

from areal.platforms import current_platform
from areal.utils import name_resolve, pkg_version
+ from areal.utils.pkg_version import is_version_less

+ uvloop.install()


@dataclass
@@ -129,11 +127,11 @@ class GenerationHyperparameters:
default=1.0,
metadata={"help": "Sampling temperature. Higher values increase diversity."},
)
- stop_token_ids: List[int] = field(
+ stop_token_ids: list[int] = field(
default_factory=list,
metadata={"help": "Stop generation when encountering these token IDs."},
)
- stop: List[str] | None = field(
+ stop: list[str] | None = field(
default=None,
metadata={
"help": "One or multiple stop words. Generation will stop if one of these words is sampled."
@@ -232,7 +230,7 @@ class OptimizerConfig:
class FSDPWrapPolicy:
"""Policy configuration for FSDP model layer wrapping. None defaults to wrapping transformer decoder layers defined by transformers."""

- transformer_layer_cls_to_wrap: List[str] | None = field(
+ transformer_layer_cls_to_wrap: list[str] | None = field(
default=None,
metadata={"help": "A list of transformer layer names for FSDP to wrap."},
)
@@ -310,7 +308,7 @@ class MegatronEngineConfig:
recompute_method: str | None = "uniform"
recompute_num_layers: int | None = 1
distribute_saved_activations: bool | None = None
- recompute_modules: List[str] | None = None
+ recompute_modules: list[str] | None = None


@dataclass
@@ -378,7 +376,7 @@ class TrainEngineConfig:
)
lora_rank: int = field(default=32, metadata={"help": "lora rank"})
lora_alpha: int = field(default=16, metadata={"help": "lora alpha"})
- target_modules: List[str] = field(
+ target_modules: list[str] = field(
default_factory=list,
metadata={"help": "lora target_modules."},
)
@@ -486,12 +484,10 @@ class PPOActorConfig(TrainEngineConfig):
},
)
# Advanced Options
- dynamic_sampling: bool = field(
- default=False,
+ dynamic_sampling_strategy: str = field(
+ default="none",
metadata={
- "help": "Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. "
- "Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, "
- "you should use the `should_accept` parameter in `rollout_batch` and `prepare_batch`."
+ "help": "Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. See the doc for more details"
},
)

@@ -500,7 +496,7 @@ class PPOActorConfig(TrainEngineConfig):
default=False,
metadata={"help": "Log statistics for agent trajectories"},
)
- log_agent_stats_keys: List[str] = field(
+ log_agent_stats_keys: list[str] = field(
default_factory=lambda: [],
metadata={"help": "Keys for logging agent trajectory statistics"},
)
@@ -574,7 +570,7 @@ def build_args(
port,
dist_init_addr: str | None = None,
):
- args: Dict = conf_as_dict(vllm_config)
+ args: dict = conf_as_dict(vllm_config)
args = dict(
host=host,
port=port,
@@ -608,11 +604,11 @@ def build_cmd(
if v is None or v is False or v == "":
continue
if v is True:
flags.append(f"--{k.replace('_','-')}")
flags.append(f"--{k.replace('_', '-')}")
elif isinstance(v, list):
flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}")
flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
else:
flags.append(f"--{k.replace('_','-')} {v}")
flags.append(f"--{k.replace('_', '-')} {v}")
return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}"


@@ -638,7 +634,7 @@ class SGLangConfig:
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: int | None = None
- cuda_graph_bs: List[int] | None = None
+ cuda_graph_bs: list[int] | None = None
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
@@ -667,8 +663,8 @@ class SGLangConfig:
# lora
enable_lora: bool | None = None
max_lora_rank: int | None = None
- lora_target_modules: List[str] | None = None
- lora_paths: List[str] | None = None
+ lora_target_modules: list[str] | None = None
+ lora_paths: list[str] | None = None
max_loaded_loras: int = 1
max_loras_per_batch: int = 1
lora_backend: str = "triton"
@@ -719,11 +715,11 @@ def build_cmd(
if v is None or v is False or v == "":
continue
if v is True:
flags.append(f"--{k.replace('_','-')}")
flags.append(f"--{k.replace('_', '-')}")
elif isinstance(v, list):
flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}")
flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}")
else:
flags.append(f"--{k.replace('_','-')} {v}")
flags.append(f"--{k.replace('_', '-')} {v}")
return f"python3 -m sglang.launch_server {' '.join(flags)}"

@staticmethod
@@ -738,11 +734,11 @@ def build_args(
node_rank: int = 0,
):
# Map "all-linear" to "all"
- args: Dict = conf_as_dict(sglang_config)
+ args: dict = conf_as_dict(sglang_config)
if sglang_config.enable_multithread_load or sglang_config.enable_fast_load:
- assert pkg_version.is_version_equal(
- "sglang", "0.5.2"
- ), f"Customized model loading requires exact SGLang version 0.5.2"
+ assert pkg_version.is_version_equal("sglang", "0.5.2"), (
+ "Customized model loading requires exact SGLang version 0.5.2"
+ )
model_loader_extra_config = dict(
enable_multithread_load=sglang_config.enable_multithread_load,
enable_fast_load=sglang_config.enable_fast_load,
@@ -915,8 +911,8 @@ class WandBConfig:
job_type: str | None = None
group: str | None = None
notes: str | None = None
- tags: List[str] | None = None
- config: Dict | None = None
+ tags: list[str] | None = None
+ config: dict | None = None
id_suffix: str | None = "train"


@@ -926,7 +922,7 @@ class SwanlabConfig:

project: str | None = None
name: str | None = None
- config: Dict | None = None
+ config: dict | None = None
logdir: str | None = None
mode: str | None = "disabled"
api_key: str | None = os.getenv("SWANLAB_API_KEY", None)
@@ -1023,7 +1019,7 @@ class SchedulerConfig:
endpoint: str = field(default="http://localhost:8081")
deploy_mode: str = field(default="separation")
functioncall_service_domain: str = field(default="http://localhost:8080")
- reward_functioncall_config: Dict = field(default_factory=dict)
+ reward_functioncall_config: dict = field(default_factory=dict)
reward_model_path: str = field(default="")
reward_model_service_url: str = field(default="http://localhost:30000/classify")

@@ -1076,7 +1072,7 @@ class SlurmLauncherConfig:
default="--mpi=pmi2 -K --chdir $PWD",
metadata={"help": "Additional arguments to pass to the srun command."},
)
- additional_bash_cmds: List[str] | None = field(
+ additional_bash_cmds: list[str] | None = field(
default=None,
metadata={
"help": "Additional bash commands to setup the container before running "
@@ -1244,7 +1240,7 @@ class PPOConfig(GRPOConfig):
critic: PPOCriticConfig = field(default_factory=PPOCriticConfig)


- def parse_cli_args(argv: List[str]):
+ def parse_cli_args(argv: list[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", help="Path to the main configuration file", required=True
@@ -1277,7 +1273,7 @@ def to_structured_cfg(cfg, config_cls):
return cfg


- def load_expr_config(argv: List[str], config_cls):
+ def load_expr_config(argv: list[str], config_cls):
cfg, config_file = parse_cli_args(argv)
cfg = to_structured_cfg(cfg, config_cls=config_cls)
cfg = OmegaConf.to_object(cfg)
@@ -1305,7 +1301,7 @@ def save_config(cfg, log_dir):
os.makedirs(log_dir, exist_ok=True)
config_save_path = os.path.join(log_dir, "config.yaml")
with open(config_save_path, "w") as f:
- config_dict: Dict = asdict(cfg)
+ config_dict: dict = asdict(cfg)
yaml.dump(
config_dict,
f,
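Note: the boolean `dynamic_sampling` option above is replaced by a three-valued `dynamic_sampling_strategy` (`none`, `dynamic`, `static`). A minimal sketch of overriding the renamed field the way `parse_cli_args` merges CLI overrides with OmegaConf; the `actor` nesting and override syntax are assumptions for illustration, not taken from this PR:

# Hedged sketch: only the field name `dynamic_sampling_strategy` and its
# allowed values ("none", "dynamic", "static") come from this diff; the
# "actor" key is an assumed nesting for illustration.
from omegaconf import OmegaConf

base = OmegaConf.create({"actor": {"dynamic_sampling_strategy": "none"}})
override = OmegaConf.from_dotlist(["actor.dynamic_sampling_strategy=dynamic"])
cfg = OmegaConf.merge(base, override)
assert cfg.actor.dynamic_sampling_strategy == "dynamic"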
7 changes: 3 additions & 4 deletions areal/api/workflow_api.py
@@ -1,6 +1,6 @@
from __future__ import annotations # noqa

- from typing import TYPE_CHECKING, Any, Dict
+ from typing import TYPE_CHECKING, Any

from areal.experimental.openai.types import CompletionWithTokenLogpReward

@@ -9,10 +9,9 @@


class RolloutWorkflow:
-
async def arun_episode(
self, engine: "InferenceEngine", data: Dict[str, Any]
) -> Dict[str, Any] | None | Dict[str, CompletionWithTokenLogpReward]:
self, engine: InferenceEngine, data: dict[str, Any]
) -> dict[str, Any] | None | dict[str, CompletionWithTokenLogpReward]:
"""Run a single episode of the workflow.

Note
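For context, a hypothetical minimal subclass written against the updated `arun_episode` signature; the workflow body is invented for illustration and is not part of this PR:

from typing import Any

from areal.api.workflow_api import RolloutWorkflow


class EchoWorkflow(RolloutWorkflow):
    async def arun_episode(self, engine, data: dict[str, Any]) -> dict[str, Any] | None:
        # A real workflow would drive `engine` to produce a rollout;
        # returning the input unchanged is purely illustrative.
        return data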
4 changes: 0 additions & 4 deletions areal/engine/ppo/actor.py
@@ -14,7 +14,6 @@
split_padded_tensor_dict_into_mb_list,
)
from areal.utils.functional import (
- dynamic_sampling,
gather_logprobs,
gather_logprobs_entropy,
ppo_actor_loss_fn,
@@ -46,7 +45,6 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine):
self.mask_no_eos_with_zero = config.mask_no_eos_with_zero

self.temperature = config.temperature
- self.dynamic_sampling = config.dynamic_sampling

@torch.no_grad()
def compute_logp(
@@ -164,8 +162,6 @@ def compute_advantages(self, data: Dict[str, Any]) -> None:
data["logprobs"] = old_logp

def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]:
- if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0:
- data, sampling_stat = dynamic_sampling(data, self.group_size)

attn_mask = data["attention_mask"]
loss_mask = data["loss_mask"]
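The deleted branch performed dynamic sampling inside `ppo_update`: when the batch divided evenly into groups, groups whose rewards were all identical were filtered out. A standalone sketch of that filtering idea, with invented names and semantics inferred from the removed help text (identical-reward groups carry no group-relative advantage signal):

import torch


def filter_constant_reward_groups(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
    # Boolean keep-mask over samples: drop groups whose rewards are all equal.
    groups = rewards.view(-1, group_size)
    varied = (groups != groups[:, :1]).any(dim=1)
    return varied.repeat_interleave(group_size)


rewards = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0])  # two groups of three
keep = filter_constant_reward_groups(rewards, group_size=3)
# keep -> tensor([ True,  True,  True, False, False, False])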
36 changes: 36 additions & 0 deletions areal/utils/data.py
@@ -221,6 +221,42 @@ def concat_padded_tensors(
return result


+ def truncate_dict_to_batch_size(
+ data: Dict[str, Any], batch_size: int
+ ) -> Dict[str, Any]:
+ """Truncate a dictionary containing tensors and numeric values to specified batch size.
+
+ This function handles different value types:
+ - Tensors: take first batch_size elements along the first dimension
+ - Numeric values: keep as is (no truncation)
+ - Other types: keep as is (no truncation)
+
+ Args:
+ data: Dictionary to truncate
+ batch_size: Target batch size for truncation
+
+ Returns:
+ Truncated dictionary
+ """
+ if not data:
+ return {}
+
+ result = {}
+
+ for key, value in data.items():
+ if torch.is_tensor(value) and len(value.shape) > 0:
+ # For tensors, take first batch_size elements along first dimension
+ if value.shape[0] > batch_size:
+ result[key] = value[:batch_size]
+ else:
+ result[key] = value
+ else:
+ # For numeric values and other types, keep as is
+ result[key] = value
+
+ return result
+
+
def unpack_sequence(
x: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
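A short usage sketch for the new helper (shapes and values are illustrative; the import path matches the file changed above, assuming this PR is applied):

import torch

from areal.utils.data import truncate_dict_to_batch_size

batch = {
    "input_ids": torch.zeros(8, 16, dtype=torch.long),  # [batch, seq]
    "rewards": torch.ones(8),
    "step": 42,  # scalar metadata is kept as-is
}
out = truncate_dict_to_batch_size(batch, batch_size=6)
assert out["input_ids"].shape[0] == 6
assert out["rewards"].shape[0] == 6
assert out["step"] == 42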