diff --git a/areal/api/alloc_mode.py b/areal/api/alloc_mode.py index b7f97693b..b452111e3 100644 --- a/areal/api/alloc_mode.py +++ b/areal/api/alloc_mode.py @@ -839,6 +839,7 @@ def parse(self, expression: str): AllocationValidationError: When validation rules are violated ValueError: When parsing fails """ + try: tree = self.parser.parse(expression) transformer = _ParallelStrategyTransformer() diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index aecccaafa..6039b63cb 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,14 +3,9 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Dict, List import uvloop import yaml - -from areal.utils.pkg_version import is_version_less - -uvloop.install() from hydra import compose as hydra_compose from hydra import initialize as hydra_init from hydra.core.global_hydra import GlobalHydra @@ -18,6 +13,9 @@ from areal.platforms import current_platform from areal.utils import name_resolve, pkg_version +from areal.utils.pkg_version import is_version_less + +uvloop.install() @dataclass @@ -129,11 +127,11 @@ class GenerationHyperparameters: default=1.0, metadata={"help": "Sampling temperature. Higher values increase diversity."}, ) - stop_token_ids: List[int] = field( + stop_token_ids: list[int] = field( default_factory=list, metadata={"help": "Stop generation when encountering these token IDs."}, ) - stop: List[str] | None = field( + stop: list[str] | None = field( default=None, metadata={ "help": "One or multiple stop words. Generation will stop if one of these words is sampled." @@ -232,7 +230,7 @@ class OptimizerConfig: class FSDPWrapPolicy: """Policy configuration for FSDP model layer wrapping. None defaults to wrapping transformer decoder layers defined by transformers.""" - transformer_layer_cls_to_wrap: List[str] | None = field( + transformer_layer_cls_to_wrap: list[str] | None = field( default=None, metadata={"help": "A list of transformer layer names for FSDP to wrap."}, ) @@ -310,7 +308,7 @@ class MegatronEngineConfig: recompute_method: str | None = "uniform" recompute_num_layers: int | None = 1 distribute_saved_activations: bool | None = None - recompute_modules: List[str] | None = None + recompute_modules: list[str] | None = None @dataclass @@ -378,7 +376,7 @@ class TrainEngineConfig: ) lora_rank: int = field(default=32, metadata={"help": "lora rank"}) lora_alpha: int = field(default=16, metadata={"help": "lora alpha"}) - target_modules: List[str] = field( + target_modules: list[str] = field( default_factory=list, metadata={"help": "lora target_modules."}, ) @@ -486,12 +484,10 @@ class PPOActorConfig(TrainEngineConfig): }, ) # Advanced Options - dynamic_sampling: bool = field( - default=False, + dynamic_sampling_strategy: str = field( + default="none", metadata={ - "help": "Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. " - "Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, " - "you should use the `should_accept` parameter in `rollout_batch` and `prepare_batch`." + "help": "Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. Only effective when running DAPO script." 
}, ) @@ -500,7 +496,7 @@ class PPOActorConfig(TrainEngineConfig): default=False, metadata={"help": "Log statistics for agent trajectories"}, ) - log_agent_stats_keys: List[str] = field( + log_agent_stats_keys: list[str] = field( default_factory=lambda: [], metadata={"help": "Keys for logging agent trajectory statistics"}, ) @@ -574,7 +570,7 @@ def build_args( port, dist_init_addr: str | None = None, ): - args: Dict = conf_as_dict(vllm_config) + args: dict = conf_as_dict(vllm_config) args = dict( host=host, port=port, @@ -608,11 +604,11 @@ def build_cmd( if v is None or v is False or v == "": continue if v is True: - flags.append(f"--{k.replace('_','-')}") + flags.append(f"--{k.replace('_', '-')}") elif isinstance(v, list): - flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}") + flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") else: - flags.append(f"--{k.replace('_','-')} {v}") + flags.append(f"--{k.replace('_', '-')} {v}") return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}" @@ -638,7 +634,7 @@ class SGLangConfig: enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: int | None = None - cuda_graph_bs: List[int] | None = None + cuda_graph_bs: list[int] | None = None torchao_config: str = "" enable_nan_detection: bool = False enable_p2p_check: bool = False @@ -667,8 +663,8 @@ class SGLangConfig: # lora enable_lora: bool | None = None max_lora_rank: int | None = None - lora_target_modules: List[str] | None = None - lora_paths: List[str] | None = None + lora_target_modules: list[str] | None = None + lora_paths: list[str] | None = None max_loaded_loras: int = 1 max_loras_per_batch: int = 1 lora_backend: str = "triton" @@ -719,11 +715,11 @@ def build_cmd( if v is None or v is False or v == "": continue if v is True: - flags.append(f"--{k.replace('_','-')}") + flags.append(f"--{k.replace('_', '-')}") elif isinstance(v, list): - flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}") + flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") else: - flags.append(f"--{k.replace('_','-')} {v}") + flags.append(f"--{k.replace('_', '-')} {v}") return f"python3 -m sglang.launch_server {' '.join(flags)}" @staticmethod @@ -738,11 +734,11 @@ def build_args( node_rank: int = 0, ): # Map "all-linear" to "all" - args: Dict = conf_as_dict(sglang_config) + args: dict = conf_as_dict(sglang_config) if sglang_config.enable_multithread_load or sglang_config.enable_fast_load: - assert pkg_version.is_version_equal( - "sglang", "0.5.2" - ), f"Customized model loading requires exact SGLang version 0.5.2" + assert pkg_version.is_version_equal("sglang", "0.5.2"), ( + "Customized model loading requires exact SGLang version 0.5.2" + ) model_loader_extra_config = dict( enable_multithread_load=sglang_config.enable_multithread_load, enable_fast_load=sglang_config.enable_fast_load, @@ -915,8 +911,8 @@ class WandBConfig: job_type: str | None = None group: str | None = None notes: str | None = None - tags: List[str] | None = None - config: Dict | None = None + tags: list[str] | None = None + config: dict | None = None id_suffix: str | None = "train" @@ -926,7 +922,7 @@ class SwanlabConfig: project: str | None = None name: str | None = None - config: Dict | None = None + config: dict | None = None logdir: str | None = None mode: str | None = "disabled" api_key: str | None = os.getenv("SWANLAB_API_KEY", None) @@ -1023,7 +1019,7 @@ class SchedulerConfig: endpoint: str = field(default="http://localhost:8081") deploy_mode: str = 
field(default="separation") functioncall_service_domain: str = field(default="http://localhost:8080") - reward_functioncall_config: Dict = field(default_factory=dict) + reward_functioncall_config: dict = field(default_factory=dict) reward_model_path: str = field(default="") reward_model_service_url: str = field(default="http://localhost:30000/classify") @@ -1076,7 +1072,7 @@ class SlurmLauncherConfig: default="--mpi=pmi2 -K --chdir $PWD", metadata={"help": "Additional arguments to pass to the srun command."}, ) - additional_bash_cmds: List[str] | None = field( + additional_bash_cmds: list[str] | None = field( default=None, metadata={ "help": "Additional bash commands to setup the container before running " @@ -1244,7 +1240,7 @@ class PPOConfig(GRPOConfig): critic: PPOCriticConfig = field(default_factory=PPOCriticConfig) -def parse_cli_args(argv: List[str]): +def parse_cli_args(argv: list[str]): parser = argparse.ArgumentParser() parser.add_argument( "--config", help="Path to the main configuration file", required=True @@ -1277,7 +1273,7 @@ def to_structured_cfg(cfg, config_cls): return cfg -def load_expr_config(argv: List[str], config_cls): +def load_expr_config(argv: list[str], config_cls): cfg, config_file = parse_cli_args(argv) cfg = to_structured_cfg(cfg, config_cls=config_cls) cfg = OmegaConf.to_object(cfg) @@ -1305,7 +1301,7 @@ def save_config(cfg, log_dir): os.makedirs(log_dir, exist_ok=True) config_save_path = os.path.join(log_dir, "config.yaml") with open(config_save_path, "w") as f: - config_dict: Dict = asdict(cfg) + config_dict: dict = asdict(cfg) yaml.dump( config_dict, f, diff --git a/areal/api/workflow_api.py b/areal/api/workflow_api.py index b551022f0..581dca92b 100644 --- a/areal/api/workflow_api.py +++ b/areal/api/workflow_api.py @@ -1,6 +1,6 @@ from __future__ import annotations # noqa -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from areal.experimental.openai.types import CompletionWithTokenLogpReward @@ -9,10 +9,9 @@ class RolloutWorkflow: - async def arun_episode( - self, engine: "InferenceEngine", data: Dict[str, Any] - ) -> Dict[str, Any] | None | Dict[str, CompletionWithTokenLogpReward]: + self, engine: InferenceEngine, data: dict[str, Any] + ) -> dict[str, Any] | None | dict[str, CompletionWithTokenLogpReward]: """Run a single episode of the workflow. 
Note diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index ec84b6ef2..19a762584 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -14,7 +14,6 @@ split_padded_tensor_dict_into_mb_list, ) from areal.utils.functional import ( - dynamic_sampling, gather_logprobs, gather_logprobs_entropy, ppo_actor_loss_fn, @@ -46,7 +45,6 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine): self.mask_no_eos_with_zero = config.mask_no_eos_with_zero self.temperature = config.temperature - self.dynamic_sampling = config.dynamic_sampling @torch.no_grad() def compute_logp( @@ -164,8 +162,6 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: data["logprobs"] = old_logp def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: - if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = dynamic_sampling(data, self.group_size) attn_mask = data["attention_mask"] loss_mask = data["loss_mask"] diff --git a/areal/utils/data.py b/areal/utils/data.py index 8d905613f..cd4a8eb70 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -221,6 +221,42 @@ def concat_padded_tensors( return result +def truncate_dict_to_batch_size( + data: Dict[str, Any], batch_size: int +) -> Dict[str, Any]: + """Truncate a dictionary containing tensors and numeric values to specified batch size. + + This function handles different value types: + - Tensors: take first batch_size elements along the first dimension + - Numeric values: keep as is (no truncation) + - Other types: keep as is (no truncation) + + Args: + data: Dictionary to truncate + batch_size: Target batch size for truncation + + Returns: + Truncated dictionary + """ + if not data: + return {} + + result = {} + + for key, value in data.items(): + if torch.is_tensor(value) and len(value.shape) > 0: + # For tensors, take first batch_size elements along first dimension + if value.shape[0] > batch_size: + result[key] = value[:batch_size] + else: + result[key] = value + else: + # For numeric values and other types, keep as is + result[key] = value + + return result + + def unpack_sequence( x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, diff --git a/areal/utils/functional.py b/areal/utils/functional.py index 171289428..0a04ac617 100644 --- a/areal/utils/functional.py +++ b/areal/utils/functional.py @@ -1,6 +1,6 @@ import functools import warnings -from typing import Any, Dict, Optional, Tuple +from typing import Any import numpy as np import torch @@ -130,7 +130,7 @@ def gather_logprobs_entropy( @torch.no_grad() def masked_normalization( x: torch.Tensor, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, dim=None, unbiased=False, eps=1e-5, @@ -175,10 +175,10 @@ def ppo_actor_loss_fn( advantages: torch.Tensor, eps_clip: float, loss_mask: torch.Tensor, - eps_clip_higher: Optional[float] = None, - c_clip: Optional[float] = None, - behav_imp_weight_cap: Optional[float] = None, -) -> Tuple[torch.Tensor, Dict]: + eps_clip_higher: float | None = None, + c_clip: float | None = None, + behav_imp_weight_cap: float | None = None, +) -> tuple[torch.Tensor, dict]: """ When decoupled loss is disabled: 1. 
if recompute logp, both old_logprobs and proximal_logprobs are recomputed logp; @@ -249,9 +249,9 @@ def ppo_critic_loss_fn( old_value: torch.FloatTensor, target_value: torch.FloatTensor, value_eps_clip: float, - loss_mask: Optional[torch.Tensor] = None, + loss_mask: torch.Tensor | None = None, loss_fn_type: str = "mse", -) -> Tuple[torch.Tensor, Dict]: +) -> tuple[torch.Tensor, dict]: """Compute PPO critic loss function given padded batch inputs. There is no shape requirements for the inputs, but they must have the same shape. @@ -311,9 +311,13 @@ def ppo_critic_loss_fn( return value_loss, stat -def dynamic_sampling( - data: Dict[str, Any], group_size: int -) -> Tuple[Dict[str, Any], Dict[str, int]]: +def filter_batch(filter_batch_fn, data: dict[str, Any], group_size: int): + return filter_batch_fn(data, group_size) + + +def filter_batch_fn_DAPO( + data: dict[str, Any], group_size: int +) -> tuple[dict[str, Any], dict[str, int]]: """Filter samples by group when all rewards in a group are equal. Assumes samples of the same group are adjacent in the batch. @@ -354,15 +358,16 @@ def dynamic_sampling( # Expand the group mask to individual samples mask = valid_groups.repeat_interleave(group_size) - # In case all group is filtered out, return the original data (although not gradient in this case) + # In case all groups are filtered out, keep the first group to prevent an infinite loop in the data collection process (though this group will not contribute to the gradient). if not mask.any(): - return data, dict(n_group_kept=0, n_group_filtered=num_groups) + mask[:group_size] = True + valid_groups[0] = True n_group_kept = int(valid_groups.sum().item()) n_group_filtered = int(num_groups - n_group_kept) # Apply mask row-wise across tensors that share the same batch dimension - filtered: Dict[str, Any] = {} + filtered: dict[str, Any] = {} for k, v in data.items(): if torch.is_tensor(v) and v.shape[:1] == (batch_size,): filtered[k] = v[mask] @@ -374,11 +379,11 @@ def dynamic_sampling( # code modified from VERL: https://github.com/volcengine/verl/blob/main/verl/workers/reward_manager/dapo.py def reward_overlong_penalty( - data: Dict[str, Any], + data: dict[str, Any], overlong_tokens: int, overlong_penalty_factor: float, max_response_length: int, -) -> Dict[str, Any]: +) -> dict[str, Any]: reward_score = data["rewards"] input_ids = data["input_ids"] response_lengths = (data["loss_mask"].sum(dim=-1)).long() diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index 1ac0009b8..56d16a994 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -1,7 +1,7 @@ import getpass import os import time -from typing import Dict, List +from typing import Any, Dict, List import swanlab import torch.distributed as dist @@ -146,3 +146,66 @@ def get_log_path(config: StatsLoggerConfig): path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}" os.makedirs(path, exist_ok=True) return path + + +def aggregate_metric_dicts(dicts: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate multiple dictionaries containing numeric values. 
+ + Args: + dicts: List of dictionaries to aggregate + + Returns: + Aggregated dictionary with the same keys + """ + if not dicts: + return {} + + # Get all unique keys from all dictionaries + all_keys = set() + for d in dicts: + all_keys.update(d.keys()) + + result = {} + + for key in all_keys: + # Collect all values for this key + values = [d.get(key) for d in dicts if key in d] + + if not values: + continue + + # Check if all values are numeric (int, float) + if all(isinstance(v, (int, float)) for v in values): + # Sum numeric values + result[key] = sum(values) + + else: + # For mixed or other types, keep as list + result[key] = values + + return result + + +def log_sampling_stats( + sampling_stats_list: list, + epoch: int, + step: int, + global_step: int, + stats_logger: StatsLogger, +): + """Helper function to log sampling statistics for both static and dynamic sampling.""" + sampling_stats = aggregate_metric_dicts(sampling_stats_list) + keep_ratio = float(sampling_stats["n_group_kept"]) / float( + sampling_stats["n_group_filtered"] + sampling_stats["n_group_kept"] + ) + sampling_stats = { + "keep_ratio": keep_ratio, + "filter_ratio": 1.0 - keep_ratio, + } + # Log the sampling statistics + stats_logger.commit( + epoch=epoch, + step=step, + global_step=global_step, + data=sampling_stats, + ) diff --git a/docs/algorithms/dapo.md b/docs/algorithms/dapo.md index 1c31d954d..013f2b603 100644 --- a/docs/algorithms/dapo.md +++ b/docs/algorithms/dapo.md @@ -43,7 +43,9 @@ We only list the different parameters from GRPO here: - `actor.overlong_penalty_factor`: The factor of overlong penalty. - `actor.eps_clip`: The lower bound of clipping, default is `0.2`. - `actor.eps_clip_higher`: The higher bound of clipping. -- `actor.dynamic_sampling`: Define if dynamic sampling should be used. +- `actor.dynamic_sampling_strategy`: The dynamic sampling strategy, selected from `none`, `static`, and `dynamic`. + + ### Overlong Penalty @@ -51,6 +53,17 @@ Here we briefly introduce the implementation details of DAPO. ![alt text](../figures/dapo_overlong_penalty.jpg) +### Dynamic Sampling Strategy + +- `none`: Turn off dynamic sampling. +- `static`: Perform a single rollout round and apply the filter function to it; this may result in a variable batch size. +- `dynamic`: Keep rolling out and filtering until the target batch size is reached, maintaining a *constant batch size*. + +If `actor.dynamic_sampling_strategy` is set to `static` or `dynamic`, a group is filtered out when all trajectories in the group have the same reward. You can customize this behavior by replacing `filter_batch_fn_DAPO` in `areal/utils/functional.py`. + + + + ## Example Usage > The algorithm is experimental and may not be stable. diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 176103dc6..1a935a7d7 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -310,56 +310,56 @@ Configuration for model optimization during training. Configuration for PPO actor model, a subclass of a TrainEngine.
-| Parameter | Type | Default | Description | -| ------------------------- | ------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `backend` | string | `""` | Training backend (refer to documentation) | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `group_size` | integer | `1` | Number of sequences in each group | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `temperature` | float | `1.0` | Temperature during generation. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. 
| -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling` | boolean | `False` | Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, you should use the `should_accept` parameter in `rollout_batch` and `prepare_batch`. | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| Parameter | Type | Default | Description | +| --------------------------- | ------------------------------------------------- | --------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. 
| +| `weight_update_mode` | string | `"disk"` | - | +| `backend` | string | `""` | Training backend (refer to documentation) | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `group_size` | integer | `1` | Number of sequences in each group | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `temperature` | float | `1.0` | Temperature during generation. | +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | +| `dynamic_sampling_strategy` | string | `"none"` | Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. Only effective when running DAPO script. | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | (section-ppo-critic)= @@ -712,7 +712,7 @@ Configuration for SwanLab experiment tracking and monitoring. 
| --------- | -------------- | ------------ | ----------- | | `project` | string \| None | `None` | - | | `name` | string \| None | `None` | - | -| `config` | `Dict` \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | | `logdir` | string \| None | `None` | - | | `mode` | string \| None | `"disabled"` | - | | `api_key` | string \| None | `None` | - | @@ -745,7 +745,7 @@ Configuration for Weights & Biases experiment tracking. | `group` | string \| None | `None` | - | | `notes` | string \| None | `None` | - | | `tags` | list of string \| None | `None` | - | -| `config` | `Dict` \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | (section-distributed-data-parallel)= @@ -808,6 +808,6 @@ Configuration for worker scheduling. Used in the single-controller mode. Experim | `endpoint` | string | `"http://localhost:8081"` | - | | `deploy_mode` | string | `"separation"` | - | | `functioncall_service_domain` | string | `"http://localhost:8080"` | - | -| `reward_functioncall_config` | `Dict` | **Required** | - | +| `reward_functioncall_config` | `dict` | **Required** | - | | `reward_model_path` | string | `""` | - | | `reward_model_service_url` | string | `"http://localhost:30000/classify"` | - | diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index f7d93df35..27fdd3271 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -13,15 +13,19 @@ from areal.platforms import current_platform from areal.utils import seeding, stats_tracker from areal.utils.data import ( + concat_padded_tensors, cycle_dataloader, + get_batch_size, + truncate_dict_to_batch_size, ) from areal.utils.dataloader import create_dataloader from areal.utils.device import log_gpu_stats from areal.utils.evaluator import Evaluator +from areal.utils.functional import filter_batch, filter_batch_fn_DAPO from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver -from areal.utils.stats_logger import StatsLogger +from areal.utils.stats_logger import StatsLogger, log_sampling_stats from areal.workflow.rlvr import RLVRWorkflow @@ -35,6 +39,7 @@ def main(args): config, _ = load_expr_config(args, GRPOConfig) config: GRPOConfig + assert config.actor.dynamic_sampling_strategy in ["none", "static", "dynamic"] rank = int(os.getenv("RANK")) tokenizer = load_hf_tokenizer(config.tokenizer_path) @@ -55,6 +60,10 @@ def main(args): split="test", dataset_config=config.valid_dataset, tokenizer=tokenizer ) + # Create dataset and dataloaders + train_loader_batch_size = ( + config.train_dataset.batch_size // actor.data_parallel_world_size + ) train_dataloader = create_dataloader( train_dataset, rank=actor.data_parallel_rank, @@ -117,7 +126,7 @@ def main(args): ), ) - # Run training. 
+ # Run training saver = Saver(config.saver, ft_spec) stats_logger = StatsLogger(config, ft_spec) evaluator = Evaluator(config.evaluator, ft_spec) @@ -143,6 +152,7 @@ max_steps = total_epochs * steps_per_epoch data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): epoch = global_step // steps_per_epoch step = global_step % steps_per_epoch @@ -152,22 +162,65 @@ epoch_step=step, steps_per_epoch=steps_per_epoch, ) + # Initialize batch collection with stats_tracker.record_timing("rollout"): - if config.async_training: - batch = actor.prepare_batch( - train_dataloader, - granularity=actor.config.group_size, - workflow=workflow, - should_accept=lambda sample: True, - ) - else: - batch = actor.rollout_batch( - next(data_generator), - granularity=actor.config.group_size, - workflow=workflow, - should_accept=lambda sample: True, - ) + collected_batches, sampling_stats, collected_batches_size = [], [], 0 + while True: + if config.async_training: + new_batch = actor.prepare_batch( + train_dataloader, + granularity=actor.config.group_size, + workflow=workflow, + should_accept=lambda sample: True, + ) + else: + new_batch = actor.rollout_batch( + next(data_generator), + granularity=actor.config.group_size, + workflow=workflow, + should_accept=lambda sample: True, + ) + + # Process the freshly collected batch according to the sampling strategy + if config.actor.dynamic_sampling_strategy in ["static", "dynamic"]: + # Filter the current batch by groups + filtered_batch, sampling_stat = filter_batch( + filter_batch_fn_DAPO, new_batch, config.actor.group_size + ) + sampling_stats.append(sampling_stat) + + if config.actor.dynamic_sampling_strategy == "static": + # Static sampling: no refilling, which may result in a smaller (variable) batch size + batch = filtered_batch + # Log sampling statistics for static sampling + log_sampling_stats( + sampling_stats, epoch, step, global_step, stats_logger + ) + break + elif config.actor.dynamic_sampling_strategy == "dynamic": + # Dynamic sampling: keep collecting batches until we reach the target batch size + # Add filtered batch to collection + collected_batches.append(filtered_batch) + collected_batches_size += get_batch_size(filtered_batch) + + # Check if we have collected enough samples + if collected_batches_size >= train_loader_batch_size: + aggregated_batch = concat_padded_tensors(collected_batches) + # Log sampling statistics for dynamic sampling + log_sampling_stats( + sampling_stats, epoch, step, global_step, stats_logger + ) + # Truncate the aggregated batch to the target batch size + batch = truncate_dict_to_batch_size( + data=aggregated_batch, + batch_size=train_loader_batch_size, + ) + break + else: + # Sampling strategy is `none`: use the current batch as is + batch = new_batch + break if config.actor.recompute_logprob or config.actor.use_decoupled_loss: with stats_tracker.record_timing("recompute_logp"): @@ -223,8 +276,6 @@ def evaluate_fn(): if actor.is_data_parallel_head(): - # Stats are logged in workflow - # and will be exported later cnt = 0 for data in valid_dataloader: for item in data: diff --git a/examples/experimental/dapo/gsm8k_dapo.yaml b/examples/experimental/dapo/gsm8k_dapo.yaml index 8b706563b..f075f1279 100644 --- a/examples/experimental/dapo/gsm8k_dapo.yaml +++ b/examples/experimental/dapo/gsm8k_dapo.yaml @@ -4,7 +4,7 @@ trial_name: trial0 seed: 1 total_train_epochs: 10 tokenizer_path: ${actor.path} -async_training: true +async_training: false cluster: n_nodes: 1 @@ -14,7 +14,7
@@ cluster: type: nfs nfs_record_root: /tmp/areal/name_resolve -allocation_mode: sglang.d4+fsdp.d4 +allocation_mode: sglang:d4+fsdp:d4 rollout: experiment_name: ${experiment_name} @@ -67,7 +67,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: true reward_norm: mean_level: group std_level: group @@ -102,7 +101,7 @@ sglang: # datasets train_dataset: - batch_size: 256 + batch_size: 16 shuffle: true pin_memory: true num_workers: 4 diff --git a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml index 9e90b9a6a..3f1817c09 100644 --- a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml +++ b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false reward_norm: mean_level: group std_level: null diff --git a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml index 7bae82c94..bf6452385 100644 --- a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml +++ b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false reward_norm: mean_level: group std_level: batch diff --git a/examples/lora/gsm8k_grpo_lora.yaml b/examples/lora/gsm8k_grpo_lora.yaml index ce114574d..2e807d434 100644 --- a/examples/lora/gsm8k_grpo_lora.yaml +++ b/examples/lora/gsm8k_grpo_lora.yaml @@ -64,7 +64,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false reward_norm: mean_level: group std_level: group diff --git a/examples/math/boba_grpo_vllm.yaml b/examples/math/boba_grpo_vllm.yaml index 9b0654225..afca6da00 100644 --- a/examples/math/boba_grpo_vllm.yaml +++ b/examples/math/boba_grpo_vllm.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/deprecated_gsm8k_grpo.yaml b/examples/math/deprecated_gsm8k_grpo.yaml index 33d98647d..efc91502a 100644 --- a/examples/math/deprecated_gsm8k_grpo.yaml +++ b/examples/math/deprecated_gsm8k_grpo.yaml @@ -66,7 +66,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false adv_norm: # Deprecated configuration, change to reward_norm and apply another adv norm mean_level: group std_level: group diff --git a/examples/math/gsm8k_grpo.yaml b/examples/math/gsm8k_grpo.yaml index 85a119247..62a36b20a 100644 --- a/examples/math/gsm8k_grpo.yaml +++ b/examples/math/gsm8k_grpo.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_ppo.yaml b/examples/math/gsm8k_ppo.yaml index 01cc80662..e44056377 100644 --- a/examples/math/gsm8k_ppo.yaml +++ b/examples/math/gsm8k_ppo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false max_new_tokens: ${gconfig.max_new_tokens} critic: diff --git a/examples/math/gsm8k_reinforce.yaml b/examples/math/gsm8k_reinforce.yaml index a8d043a65..5c056eacb 100644 --- a/examples/math/gsm8k_reinforce.yaml +++ b/examples/math/gsm8k_reinforce.yaml @@ -64,7 +64,6 @@ actor: recompute_logprob: 
true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/gsm8k_reinforce_baseline.yaml b/examples/math/gsm8k_reinforce_baseline.yaml index 3fe1c8f47..eb3534436 100644 --- a/examples/math/gsm8k_reinforce_baseline.yaml +++ b/examples/math/gsm8k_reinforce_baseline.yaml @@ -64,7 +64,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false reward_norm: mean_level: group std_level: null diff --git a/examples/math/gsm8k_rloo.yaml b/examples/math/gsm8k_rloo.yaml index 99b12c784..a4f41a4c1 100644 --- a/examples/math/gsm8k_rloo.yaml +++ b/examples/math/gsm8k_rloo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false reward_norm: mean_level: group mean_leave1out: true diff --git a/examples/multi-turn-math/config.yaml b/examples/multi-turn-math/config.yaml index 27fc735ab..d4d01aac1 100644 --- a/examples/multi-turn-math/config.yaml +++ b/examples/multi-turn-math/config.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false adv_norm: mean_level: batch std_level: batch diff --git a/examples/tir/tir_math_config.yaml b/examples/tir/tir_math_config.yaml index 9b8473fe4..06323609e 100644 --- a/examples/tir/tir_math_config.yaml +++ b/examples/tir/tir_math_config.yaml @@ -60,7 +60,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false reward_norm: mean_level: group std_level: group diff --git a/examples/vlm/clevr_count_70k_grpo.yaml b/examples/vlm/clevr_count_70k_grpo.yaml index b64aa5bc8..bdfbda561 100644 --- a/examples/vlm/clevr_count_70k_grpo.yaml +++ b/examples/vlm/clevr_count_70k_grpo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false reward_norm: mean_level: group std_level: group diff --git a/recipe/AEnt/actor.py b/recipe/AEnt/actor.py index 03b808315..852351d39 100644 --- a/recipe/AEnt/actor.py +++ b/recipe/AEnt/actor.py @@ -11,11 +11,12 @@ from areal.utils import stats_tracker from areal.utils.data import split_padded_tensor_dict_into_mb_list from areal.utils.functional import ( - dynamic_sampling, + filter_batch, + filter_batch_fn_DAPO, gather_logprobs, gather_logprobs_entropy, ppo_actor_loss_fn, reward_overlong_penalty, ) from recipe.AEnt.aent_args import AEntPPOActorConfig from recipe.AEnt.functional import gather_logprobs_clamped_entropy @@ -39,8 +40,10 @@ def __init__(self, config: AEntPPOActorConfig, engine: TrainEngine): def aent_ppo_update( self, data: TensorDict, global_step: int ) -> List[Dict[str, float]]: - if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = dynamic_sampling(data, self.group_size) + if self.config.dynamic_sampling_strategy != "none" and len(data["rewards"]) % self.group_size == 0: + data, sampling_stat = filter_batch( + filter_batch_fn_DAPO, data, self.group_size + ) attn_mask = data["attention_mask"] loss_mask = data["loss_mask"] diff --git a/recipe/AEnt/configs/gsm8k_aent_grpo.yaml b/recipe/AEnt/configs/gsm8k_aent_grpo.yaml index bf307b5a2..9d2c9d8f9 100644 --- a/recipe/AEnt/configs/gsm8k_aent_grpo.yaml +++ b/recipe/AEnt/configs/gsm8k_aent_grpo.yaml @@ -75,7 +75,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false +
dynamic_sampling_strategy: none adv_norm: mean_level: group std_level: group
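For reference, a custom filter function plugged into the `filter_batch` hook introduced by this patch only needs to follow the contract of `filter_batch_fn_DAPO`: take the batch dict and the group size, assume samples of the same group are adjacent, and return the filtered dict together with `n_group_kept`/`n_group_filtered` statistics. The sketch below keeps groups whose rewards have a non-negligible standard deviation; the function name `filter_batch_fn_reward_std` and the `min_std` threshold are illustrative and not part of the patch.

```python
from typing import Any

import torch


def filter_batch_fn_reward_std(
    data: dict[str, Any], group_size: int, min_std: float = 1e-4
) -> tuple[dict[str, Any], dict[str, int]]:
    """Keep groups whose rewards have a standard deviation above ``min_std``.

    Follows the same contract as ``filter_batch_fn_DAPO``: samples of the same
    group are adjacent in the batch, and the function returns the filtered
    batch together with group-level statistics.
    """
    rewards = data["rewards"]
    batch_size = rewards.shape[0]
    assert batch_size % group_size == 0, "batch size must be divisible by group_size"
    assert group_size > 1, "reward std needs at least two samples per group"
    num_groups = batch_size // group_size

    # Per-group reward std; groups with (near-)identical rewards carry no learning signal.
    group_rewards = rewards.view(num_groups, group_size).float()
    valid_groups = group_rewards.std(dim=1) > min_std

    # Expand the group mask to individual samples.
    mask = valid_groups.repeat_interleave(group_size)
    if not mask.any():
        # Keep the first group so the data collection loop never sees an empty batch.
        mask[:group_size] = True
        valid_groups[0] = True

    n_group_kept = int(valid_groups.sum().item())
    # Apply the mask only to tensors that share the batch dimension.
    filtered = {
        k: v[mask] if torch.is_tensor(v) and v.shape[:1] == (batch_size,) else v
        for k, v in data.items()
    }
    return filtered, dict(
        n_group_kept=n_group_kept, n_group_filtered=num_groups - n_group_kept
    )
```

Such a function could be passed to the example loop in place of `filter_batch_fn_DAPO`, e.g. `filter_batch(filter_batch_fn_reward_std, new_batch, config.actor.group_size)`, without changing how the `static` or `dynamic` strategies aggregate and log the returned statistics.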