diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py
index cd36f7a2d..f5f11feca 100644
--- a/areal/api/cli_args.py
+++ b/areal/api/cli_args.py
@@ -8,6 +8,8 @@
 import uvloop
 import yaml
 
+from areal.utils.pkg_version import is_version_less
+
 uvloop.install()
 from hydra import compose as hydra_compose
 from hydra import initialize as hydra_init
@@ -55,6 +57,18 @@ class NormConfig:
     group_size: int = field(
         default=1, metadata={"help": "Group size for group-level normalization"}
     )
+    adv_norm_mode: str = field(
+        default="native",
+        metadata={
+            "help": "Advantage normalization mode: 'native' or 'mix'. 'native' applies the standard z-score normalization; 'mix' computes both the standard z-score and the mean-based normalization and aggregates them (see the MAPO paper for details)."
+        },
+    )
+    reward_norm_mode: str = field(
+        default="native",
+        metadata={
+            "help": "Mode for reward normalization. Currently only 'native' is supported."
+        },
+    )
 
 
 @dataclass
@@ -632,6 +646,8 @@ def build_cmd(
         # convert to flags
         flags = []
         for k, v in args.items():
+            if is_version_less("sglang", "0.4.10.post2") and "max_loaded_loras" in k:
+                continue
             if v is None or v is False or v == "":
                 continue
             if v is True:
@@ -640,6 +656,7 @@ def build_cmd(
                 flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}")
             else:
                 flags.append(f"--{k.replace('_','-')} {v}")
+
         return f"python3 -m sglang.launch_server {' '.join(flags)}"
 
     @staticmethod
diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py
index 673f35a1a..bcc901f4b 100644
--- a/areal/engine/ppo/actor.py
+++ b/areal/engine/ppo/actor.py
@@ -9,7 +9,8 @@ from areal.utils import stats_tracker
 from areal.utils.data import (
     KLEstimator,
-    Normalization,
+    get_adv_norm,
+    get_reward_norm,
     split_padded_tensor_dict_into_mb_list,
 )
 from areal.utils.functional import (
@@ -35,10 +36,8 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine):
         self.kl_ctl = config.kl_ctl
         self.kl_estimator = KLEstimator(config.kl_estimator)
 
-        self.adv_norm = Normalization(config.adv_norm) if config.adv_norm else None
-        self.reward_norm = (
-            Normalization(config.reward_norm) if config.reward_norm else None
-        )
+        self.adv_norm = get_adv_norm(config)
+        self.reward_norm = get_reward_norm(config)
 
         self.discount = config.discount
         self.gae_lambda = config.gae_lambda
diff --git a/areal/utils/data.py b/areal/utils/data.py
index 8d905613f..f67fe764c 100644
--- a/areal/utils/data.py
+++ b/areal/utils/data.py
@@ -11,7 +11,7 @@
 from einops import rearrange
 from torchdata.stateful_dataloader import StatefulDataLoader
 
-from areal.api.cli_args import MicroBatchSpec, NormConfig
+from areal.api.cli_args import MicroBatchSpec, NormConfig, PPOActorConfig
 from areal.platforms import current_platform
 from areal.utils import datapack, logging
@@ -1070,6 +1070,7 @@ def cycle_dataloader(dataloader: StatefulDataLoader):
         g = iter(dataloader)
 
 
+# Base "native" normalization implementation, used for both reward and advantage normalization.
 class Normalization:
     """
     Adaptive normalization with different levels.
@@ -1108,7 +1109,11 @@ def __call__(
         loss_mask: Optional[torch.Tensor] = None,
         high_precision: bool = True,
         reduce_group=None,
+        calculation_base: str = "deviation",
     ) -> torch.Tensor:
+
+        # x can be either advantages or rewards, with shape [bs * self.group_size, max_tokens]
+        bs = x.size(0)
         eps = self.eps
@@ -1200,8 +1205,15 @@ def __call__(
             std = torch.ones_like(x)
             eps = 0.0
 
+        assert calculation_base in [
+            "mean",
+            "deviation",
+        ], "calculation_base must be either 'mean' or 'deviation'"
+        base = std if calculation_base == "deviation" else mean
+        # Ensure numerical stability (avoid modifying std/mean in place)
+        base = base + eps
         # Normalize
-        return (x_centered / (std + eps)).float()
+        return (x_centered / base).float()
 
     @staticmethod
     def _compute_mean(
@@ -1362,3 +1374,116 @@ def _compute_approx_kl(
     if apply_clamp:
         log_ratio = log_ratio.clamp(min=-10, max=10)
     return log_ratio
+
+
+# Mixed advantage normalization from the MAPO paper, derived from the native
+# normalization implementation above.
+class MAPOAdvNorm(Normalization):
+    def __call__(self, advantages, loss_mask=None, **kwargs) -> torch.Tensor:
+        # advantages has shape [batch_size * group_size, max_token];
+        # deviation_base_norm is the standard ("native") z-score normalization.
+        deviation_base_norm = super().__call__(
+            advantages, loss_mask=loss_mask, calculation_base="deviation", **kwargs
+        )
+
+        # Distinct non-zero values in the advantages tensor. Zeros are excluded
+        # because they correspond to padded tokens.
+        unique_values = torch.unique(advantages[advantages != 0])
+        num_unique = unique_values.numel()
+
+        if num_unique >= 3 or num_unique <= 1:
+            if num_unique >= 3:
+                # More than two distinct values: the reward is not binary.
+                logger.warning(
+                    msg=f"MAPO only supports binary reward modeling, but detected {num_unique} unique non-zero values in the advantages tensor. Please check that: "
+                    f"1. the reward function returns binary values; "
+                    f"2. overlong_reward_penalty is set to false."
+                )
+            else:
+                # Zero or one distinct value: all advantages in the batch are identical.
+                logger.info(
+                    "All advantages in the batch are identical; please check your reward function."
+                )
+
+            logger.info("Falling back to native advantage normalization.")
+            return deviation_base_norm
+
+        # 'unique_upper_value' is the reward of a successful trajectory
+        # (the lower value corresponds to a failed trajectory).
+        unique_upper_value = unique_values.max().item()
+
+        # mean_base_norm shape [batch_size * group_size, max_token]
+        mean_base_norm = super().__call__(
+            advantages, loss_mask=loss_mask, calculation_base="mean", **kwargs
+        )
+
+        bs, max_token = int(advantages.shape[0] / self.group_size), advantages.shape[-1]
+
+        # The advantage is identical for every token of a trajectory, so the
+        # trajectory-level advantage can be read from the first token.
+        advantages_ = advantages[:, 0]  # shape [batch_size * group_size]
+        advantages_ = advantages_.reshape(
+            bs, self.group_size
+        )  # shape [batch_size, group_size]
+
+        # Number of successful trajectories within each group
+        success_trajectory_nums_per_group = (advantages_ == unique_upper_value).sum(
+            dim=1
+        )  # shape [batch_size]
+        # Total number of trajectories within each group
+        total_trajectory_nums_per_group = torch.tensor([self.group_size] * bs).to(
+            device=success_trajectory_nums_per_group.device,
+            dtype=success_trajectory_nums_per_group.dtype,
+        )  # shape [batch_size]
+        # Probability of a successful trajectory within each group
+        p = success_trajectory_nums_per_group / total_trajectory_nums_per_group
+
+        # trajectory_reweight shape [batch_size]: the reweighting of trajectories.
+        # p == 0:   all trajectories fail    -> trajectory_reweight == 1 -> only mean_base_norm
+        # p == 1:   all trajectories succeed -> trajectory_reweight == 1 -> only mean_base_norm
+        # p == 0.5: half of them succeed     -> trajectory_reweight == 0 -> only deviation_base_norm
+        trajectory_reweight = 1 - (4 * p * (1 - p))
+
+        # Expand trajectory_reweight to match the per-token advantages:
+        # [batch_size] -> [batch_size * group_size] -> [batch_size * group_size, max_token].
+        # Every token of a trajectory shares the same reweighting factor, i.e. the
+        # granularity goes group-level -> trajectory-level -> token-level.
+        trajectory_reweight = (
+            trajectory_reweight.repeat_interleave(self.group_size)
+            .unsqueeze(-1)
+            .expand(-1, max_token)
+        )
+        # 'trajectory_reweight', 'deviation_base_norm' and 'mean_base_norm' now share
+        # the same shape, so the mixing below is a plain element-wise operation.
+        return (
+            1 - trajectory_reweight
+        ) * deviation_base_norm + trajectory_reweight * mean_base_norm
+
+
+def get_reward_norm(config: PPOActorConfig):
+    if config.reward_norm:
+        return Normalization(config.reward_norm)
+    else:
+        return None
+
+
+def get_adv_norm(config: PPOActorConfig):
+    if config.adv_norm:
+        if config.adv_norm.adv_norm_mode == "mix":
+            assert (
+                config.reward_bias == 0.0
+            ), "When using mixed advantage normalization (MAPO), reward_bias should be 0.0 to ensure a binary reward."
+            return MAPOAdvNorm(config.adv_norm)
+        else:
+            return Normalization(config.adv_norm)
+    return None
diff --git a/areal/utils/functional.py b/areal/utils/functional.py
index 11421b8ee..7c4bccdad 100644
--- a/areal/utils/functional.py
+++ b/areal/utils/functional.py
@@ -107,6 +107,7 @@ def masked_normalization(
     high_precision=True,
     all_reduce=True,
     reduce_group=None,
+    calculation_base: str = "deviation",
 ):
     dtype = torch.float64 if high_precision else torch.float32
     x = x.to(dtype)
@@ -135,7 +136,17 @@ def masked_normalization(
         var = meansq - mean**2
     if unbiased:
         var *= factor / (factor - 1)
-    return ((x - mean) / (var.sqrt() + eps)).float()
+    assert calculation_base in [
+        "mean",
+        "deviation",
+    ], "calculation_base must be either 'mean' or 'deviation'"
+
+    std = var.sqrt()
+    base = std if calculation_base == "deviation" else mean
+    # Ensure numerical stability
+    base = base + eps
+    # Normalize
+    return ((x - mean) / base).float()
 
 
 def ppo_actor_loss_fn(
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 238605ef8..d77e3b2e1 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -42,6 +42,7 @@ parts:
       - file: algorithms/dapo
       - file: algorithms/dr.GRPO
      - file: algorithms/litePPO
+      - file: algorithms/mapo
   - caption: Customization (Legacy)
     chapters:
       - file: legacy/customization/dataset
diff --git a/docs/algorithms/mapo.md b/docs/algorithms/mapo.md
new file mode 100644
index 000000000..8b60bb0cb
--- /dev/null
+++ b/docs/algorithms/mapo.md
@@ -0,0 +1,53 @@
+# Mixed Advantage Policy Optimization (MAPO)
+
+Last updated: Sep 27, 2025
+
+Author: [Ziyi ZENG](https://github.com/ZiyiTsang)
+
+![MAPO overview](../figures/MAPO.jpg)
+
+Mixed Advantage Policy Optimization (MAPO) is an improved Group Relative Policy Optimization (GRPO) strategy designed to enhance the reasoning performance of foundation models. While GRPO is effective for post-training foundation models on reasoning tasks, it suffers from the "advantage reversion" and "advantage mirror" problems, which lead to an unreasonable allocation of advantage across query samples. MAPO addresses these limitations by introducing the concept of "trajectory certainty" and proposing the "Advantage Percent Deviation" (APD) for high-certainty trajectories. Furthermore, it dynamically reweights the advantage function based on trajectory certainty through the "Trajectory Certainty Reweight" (TCR). This adaptive approach allows MAPO to tailor the advantage function to sample-specific characteristics, mitigating the shortcomings of prior advantage formulations and producing more stable and accurate reasoning performance across diverse tasks.
+
+The overall surrogate objective is:
+
+$$\mathcal{J}_{\mathrm{GRPO}}(\theta)=\mathbb{E}_{q\sim\rho_{Q}}\mathbb{E}_{o\sim\pi_{old}(\cdot|q)}\left[\frac{1}{G}\sum_{i}^{G}f_{\epsilon}\left(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{old}(o_{i}|q)},\hat{A}_{i}^{*}\right)\right]-\beta\mathbb{D}_{KL}[\pi_{\theta}||\pi_{ref}],$$
+
+where
+
+$$f_\epsilon(x,y)=\min(xy,\mathrm{clip}(x,1-\epsilon,1+\epsilon)y),$$
+
+$$\lambda(p)=1-4p(1-p)\in[0,1]\quad(p\in[0,1]),$$
+
+$$\hat{A}_i^*=(1-\lambda(p))\cdot\underbrace{\frac{r_i-\mu}{\sigma}}_{\text{Deviation-based}}+\lambda(p)\cdot\underbrace{\frac{r_i-\mu}{\mu}}_{\text{Mean-based}}.$$
+
+For more details:
+
+- AReaL details: [Paper of AReaL](https://arxiv.org/abs/2505.24298)
+
+- MAPO details: [Paper of MAPO](https://arxiv.org/abs/2509.18849v3)
+
+## Algorithm Core Parameters
+
+- `actor.adv_norm.adv_norm_mode`: selects the advantage-normalization implementation. `native` is the z-score normalization used by GRPO, while `mix` is the MAPO implementation (a minimal numeric sketch is shown below).
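+
+The following minimal, self-contained sketch (illustrative only; it is not code from the AReaL repository, whose implementation applies the mixing to token-level advantage tensors and respects `mean_level`/`std_level`) shows how the mixed advantage is formed for a toy group of binary rewards:
+
+```python
+import torch
+
+# Illustrative standalone example of the MAPO mixing rule, not the AReaL implementation.
+# Toy group of G = 4 trajectories with binary rewards: 3 successes, 1 failure.
+rewards = torch.tensor([1.0, 1.0, 1.0, 0.0])
+
+mu, sigma = rewards.mean(), rewards.std()  # std is unbiased by default
+p = (rewards == rewards.max()).float().mean()  # success rate within the group
+lam = 1 - 4 * p * (1 - p)  # trajectory-certainty reweight lambda(p)
+
+deviation_based = (rewards - mu) / (sigma + 1e-5)  # standard z-score ("native")
+mean_based = (rewards - mu) / (mu + 1e-5)  # Advantage Percent Deviation (APD)
+
+mixed_advantage = (1 - lam) * deviation_based + lam * mean_based
+print(p.item(), lam.item(), mixed_advantage)
+```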
+
+## Notice
+
+For the MAPO implementation, the following constraints should be met:
+
+1. The reward function should return a binary result (exactly two distinct values). The higher value marks a successful trajectory, while the lower value marks a failed one.
+2. `overlong_reward_penalty` should be disabled.
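+
+For reference, the example configuration shipped with this PR (`examples/experimental/mapo/gsm8k_mapo.yaml`) enables MAPO through the `actor.adv_norm` block shown in this excerpt:
+
+```yaml
+# Excerpt from examples/experimental/mapo/gsm8k_mapo.yaml
+actor:
+  adv_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+    adv_norm_mode: mix
+```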
+
+## Example Usage
+
+We recommend changing these parameters in the configuration file
+(i.e., `gsm8k_mapo.yaml`).
+
+| Backend | CMD |
+| --------- | ------------------------------------------------------------------------------------------------------------------------------- |
+| **local** | `python3 -m areal.launcher.local examples/experimental/mapo/gsm8k_mapo.py --config examples/experimental/mapo/gsm8k_mapo.yaml --` |
+| **ray** | `python3 -m areal.launcher.ray examples/experimental/mapo/gsm8k_mapo.py --config examples/experimental/mapo/gsm8k_mapo.yaml --` |
+| **slurm** | `python3 -m areal.launcher.slurm examples/experimental/mapo/gsm8k_mapo.py --config examples/experimental/mapo/gsm8k_mapo.yaml --` |
+
+## Baselines
+
+We do not have baseline results yet; contributions are welcome!
diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index 39293808f..11078de76 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -269,14 +269,16 @@ Specification for splitting micro-batches during training.
 
 Configuration for reward/advantage normalization.
 
-| Parameter | Type | Default | Description |
-| ---------------- | -------------- | --------- | ----------------------------------------------------------------------------------------------------------------- |
-| `mean_level` | string \| None | `"batch"` | Mean level for normalization. None for no mean normalization. **Choices:** `batch`, `group`, `None` |
-| `mean_leave1out` | boolean | `False` | Whether to use leave-one-out average. |
-| `std_level` | string \| None | `"batch"` | Standard deviation level for normalization. None for no std normalization. **Choices:** `batch`, `group`, `None` |
-| `std_unbiased` | boolean | `True` | Whether to use unbiased standard deviation computation. Defaults to True (changed from False in v0.3.4). |
-| `eps` | float | `1e-05` | The eps when dividing by standard deviation to avoid numerical issues. |
-| `group_size` | integer | `1` | Group size for group-level normalization |
+| Parameter | Type | Default | Description |
+| ------------------ | -------------- | ---------- | ----------------------------------------------------------------------------------------------------------------- |
+| `mean_level` | string \| None | `"batch"` | Mean level for normalization. None for no mean normalization. **Choices:** `batch`, `group`, `None` |
+| `mean_leave1out` | boolean | `False` | Whether to use leave-one-out average. |
+| `std_level` | string \| None | `"batch"` | Standard deviation level for normalization. None for no std normalization. **Choices:** `batch`, `group`, `None` |
+| `std_unbiased` | boolean | `True` | Whether to use unbiased standard deviation computation. Defaults to True (changed from False in v0.3.4). |
+| `eps` | float | `1e-05` | The eps when dividing by standard deviation to avoid numerical issues. |
+| `group_size` | integer | `1` | Group size for group-level normalization |
+| `adv_norm_mode` | string | `"native"` | Advantage normalization mode: 'native' or 'mix'. 'native' applies the standard z-score normalization; 'mix' computes both the standard z-score and the mean-based normalization and aggregates them (see the MAPO paper for details). |
+| `reward_norm_mode` | string | `"native"` | Mode for reward normalization. Currently only 'native' is supported. |
 
 (section-optimizer)=
diff --git a/docs/figures/MAPO.jpg b/docs/figures/MAPO.jpg
new file mode 100644
index 000000000..c6acf9438
Binary files /dev/null and b/docs/figures/MAPO.jpg differ
diff --git a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml
index 9e90b9a6a..276e09230 100644
--- a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml
+++ b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml
@@ -1,4 +1,4 @@
-experiment_name: gsm8k-grpo
+experiment_name: gsm8k-drgrpo
 trial_name: trial0
 
 seed: 1
diff --git a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml
index 7bae82c94..be1966cb3 100644
--- a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml
+++ b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml
@@ -1,4 +1,4 @@
-experiment_name: gsm8k-grpo
+experiment_name: gsm8k-liteppo
 trial_name: trial0
 
 seed: 1
diff --git a/examples/experimental/mapo/gsm8k_mapo.py b/examples/experimental/mapo/gsm8k_mapo.py
new file mode 100644
index 000000000..967fbd68a
--- /dev/null
+++ b/examples/experimental/mapo/gsm8k_mapo.py
@@ -0,0 +1,307 @@
+import os
+import sys
+from copy import deepcopy
+
+import torch.distributed as dist
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from areal.api.alloc_mode import AllocationMode
+from areal.api.cli_args import GRPOConfig, load_expr_config
+from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta
+from areal.dataset import get_custom_dataset
+from areal.engine.ppo.actor import FSDPPPOActor
+from areal.engine.sglang_remote import RemoteSGLangEngine
+from areal.platforms import current_platform
+from areal.utils import seeding, stats_tracker
+from areal.utils.data import (
+    broadcast_tensor_container,
+    cycle_dataloader,
+    tensor_container_to,
+)
+from areal.utils.device import log_gpu_stats
+from areal.utils.evaluator import Evaluator
+from areal.utils.hf_utils import load_hf_tokenizer
+from areal.utils.recover import RecoverHandler
+from areal.utils.saver import Saver
+from areal.utils.stats_logger import StatsLogger
+from areal.workflow.rlvr import RLVRWorkflow
+
+
+def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
+    from areal.reward.math_parser import process_results
+
+    return int(process_results(completions, answer)[0])
+
+
+def main(args):
+    config, _ = load_expr_config(args, GRPOConfig)
+    config: GRPOConfig
+
+    rank = int(os.getenv("RANK"))
+    tokenizer = load_hf_tokenizer(config.tokenizer_path)
+
+    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
+    allocation_mode = AllocationMode.from_str(config.allocation_mode)
+    parallel_strategy = allocation_mode.train
+    assert parallel_strategy is not None
+
+    # Initialize train engine
+    actor = FSDPPPOActor(config=config.actor)
+    actor.create_process_group(parallel_strategy=parallel_strategy)
+
+    train_dataset = get_custom_dataset(
+        path=config.train_dataset.path,
+        rank=actor.data_parallel_rank,
+        world_size=actor.data_parallel_world_size,
+        split="train",
+        max_length=config.train_dataset.max_length,
+        type=config.train_dataset.type,
+        tokenizer=tokenizer,
+    )
+    valid_dataset = get_custom_dataset(
+        path=config.valid_dataset.path,
+        rank=actor.data_parallel_rank,
world_size=actor.data_parallel_world_size, + split="test", + max_length=config.valid_dataset.max_length, + type=config.valid_dataset.type, + tokenizer=tokenizer, + ) + + # Create dataset and dataloaders + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.train_dataset.drop_last, + ) + valid_dataloader = StatefulDataLoader( + valid_dataset, + batch_size=config.valid_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.valid_dataset.shuffle, + num_workers=config.valid_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.valid_dataset.drop_last, + ) + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize inference engine + rollout = RemoteSGLangEngine(config.rollout) + rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size) + eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout)) + # NOTE: eval does not have any offpolicyness control + eval_rollout.config.max_head_offpolicyness = int(1e12) + eval_rollout.initialize() + + actor.initialize(None, ft_spec) + ref = None + if config.actor.kl_ctl > 0 and config.ref is not None: + ref = FSDPPPOActor(config=config.ref) + ref.create_process_group(parallel_strategy=parallel_strategy) + ref.initialize(None, ft_spec) + + # NOTE: Weight update meta only requires address and free port of rank 0, + # but `WeightUpdateMeta.from_fsdp_xccl` has to be executed on all ranks + # due to `engine.get_param_specs()`. + # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. + weight_update_meta = [ + WeightUpdateMeta.from_fsdp_xccl( + AllocationMode.from_str(config.allocation_mode), actor + ) + ] + dist.broadcast_object_list(weight_update_meta, src=0) + weight_update_meta = weight_update_meta[0] + + # Create rollout workflow + if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) + if tokenizer.eos_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) + workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig, + tokenizer=tokenizer, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated" + ), + ) + eval_workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig.new(temperature=0.6), + tokenizer=tokenizer, + enable_thinking=False, + rollout_stat_scope="eval-rollout", + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated-eval" + ), + ) + + # Run training. 
+ saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config.stats_logger, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + batch = None + if actor.is_data_parallel_head(): + if config.async_training: + batch = rollout.prepare_batch( + train_dataloader, + workflow=workflow, + should_accept=lambda sample: True, + ) + else: + batch = rollout.rollout_batch( + next(data_generator), + workflow=workflow, + should_accept=lambda sample: True, + ) + batch = tensor_container_to(batch, actor.device) + batch = broadcast_tensor_container( + batch, + src_rank=actor.current_data_parallel_head(), + group=actor.context_and_model_parallel_group, + ) + # Create barrier to synchronize all rollout processes. + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + + with ( + stats_tracker.record_timing("train_step"), + stats_tracker.scope("grpo_actor"), + ): + stats = actor.ppo_update(batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + if dist.get_rank() == 0: + future = rollout.update_weights(weight_update_meta) + actor.upload_weights(weight_update_meta) + if dist.get_rank() == 0: + future.result() + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + eval_rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + with stats_tracker.record_timing("eval"): + + def evaluate_fn(): + if actor.is_data_parallel_head(): + cnt = 0 + for data in valid_dataloader: + for item in data: + eval_rollout.submit(item, eval_workflow) + cnt += 1 + eval_rollout.wait(cnt, timeout=None) + dist.barrier(device_ids=[actor.device.index]) + 
current_platform.synchronize() + + evaluator.evaluate( + evaluate_fn, + epoch, + step, + global_step, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Upload statistics to the logger (e.g., wandb) + stats[0].update( + stats_tracker.export_all(reduce_group=actor.data_parallel_group) + ) + stats_logger.commit(epoch, step, global_step, stats) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Resume rollout + rollout.resume() + + stats_logger.close() + eval_rollout.destroy() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/experimental/mapo/gsm8k_mapo.yaml b/examples/experimental/mapo/gsm8k_mapo.yaml new file mode 100644 index 000000000..890639ca9 --- /dev/null +++ b/examples/experimental/mapo/gsm8k_mapo.yaml @@ -0,0 +1,151 @@ +experiment_name: gsm8k-mapo +trial_name: trial0 + +seed: 1 +total_train_epochs: 10 +tokenizer_path: ${actor.path} +async_training: true + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: sglang.d4p1t1+d4p1t1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + backend: fsdp + group_size: ${gconfig.n_samples} + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: 5.0 + dynamic_sampling: false + adv_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm_mode: mix + max_new_tokens: ${gconfig.max_new_tokens} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + backend: fsdp + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +# datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + 
+evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +launcher: + inference_server_cpus_per_gpu: 4 + inference_server_mem_per_gpu: 32768 + trainer_cpus_per_gpu: 4 + trainer_mem_per_gpu: 32768