From 393c2e7fbd769b1dcd536c6e69f6ef478195821f Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Tue, 14 Oct 2025 00:06:00 -0900 Subject: [PATCH 01/24] . --- areal/api/alloc_mode.py | 9 ++ areal/engine/ppo/actor.py | 4 - areal/experimental/megatron_actor.py | 4 +- areal/utils/data.py | 90 ++++++++++++++++ areal/utils/functional.py | 6 +- examples/experimental/dapo/gsm8k_dapo.py | 116 +++++++++++++++------ examples/experimental/dapo/gsm8k_dapo.yaml | 4 +- 7 files changed, 194 insertions(+), 39 deletions(-) diff --git a/areal/api/alloc_mode.py b/areal/api/alloc_mode.py index b7f97693b..e2cda6f4f 100644 --- a/areal/api/alloc_mode.py +++ b/areal/api/alloc_mode.py @@ -839,6 +839,15 @@ def parse(self, expression: str): AllocationValidationError: When validation rules are violated ValueError: When parsing fails """ + # Check for common syntax errors and provide helpful messages + if "." in expression and ":" not in expression: + # User likely used dot notation instead of colon + raise ValueError( + f"Invalid allocation mode syntax: '{expression}'\n" + f"Please use colon ':' instead of dot '.' to separate backend and dimensions.\n" + f"Example: 'sglang:d4+fsdp:d4' instead of 'sglang.d4+fsdp.d4'" + ) + try: tree = self.parser.parse(expression) transformer = _ParallelStrategyTransformer() diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index 673f35a1a..c7e3c0e67 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -13,7 +13,6 @@ split_padded_tensor_dict_into_mb_list, ) from areal.utils.functional import ( - dynamic_sampling, gather_logprobs, gather_logprobs_entropy, ppo_actor_loss_fn, @@ -45,7 +44,6 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine): self.mask_no_eos_with_zero = config.mask_no_eos_with_zero self.temperature = config.temperature - self.dynamic_sampling = config.dynamic_sampling @torch.no_grad() def compute_logp( @@ -160,8 +158,6 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: data["logprobs"] = old_logp def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: - if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = dynamic_sampling(data, self.group_size) attn_mask = data["attention_mask"] loss_mask = data["loss_mask"] diff --git a/areal/experimental/megatron_actor.py b/areal/experimental/megatron_actor.py index ae8ebf2a2..a143c31c8 100644 --- a/areal/experimental/megatron_actor.py +++ b/areal/experimental/megatron_actor.py @@ -13,7 +13,7 @@ from areal.utils import stats_tracker from areal.utils.data import Normalization, split_padded_tensor_dict_into_mb_list from areal.utils.functional import ( - dynamic_sampling, + filter_batch, gather_logprobs, gather_logprobs_entropy, ppo_actor_loss_fn, @@ -179,7 +179,7 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = dynamic_sampling(data, self.group_size) + data, sampling_stat = filter_batch(data, self.group_size) attn_mask = data["attention_mask"] loss_mask = data["loss_mask"] diff --git a/areal/utils/data.py b/areal/utils/data.py index 8d905613f..0f2f66879 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -221,6 +221,96 @@ def concat_padded_tensors( return result +def aggregate_dicts( + dicts: List[Dict[str, Any]], pad_value: float = 0.0 +) -> Dict[str, Any]: + """Aggregate multiple dictionaries containing tensors and 
numeric values. + + This function handles different value types: + - Tensors: concatenated and padded to max length + - Numeric values: summed across dictionaries + - Other types: kept as lists + + Args: + dicts: List of dictionaries to aggregate + pad_value: Value to use for padding tensors + + Returns: + Aggregated dictionary with the same keys + """ + if not dicts: + return {} + + # Get all unique keys from all dictionaries + all_keys = set() + for d in dicts: + all_keys.update(d.keys()) + + result = {} + + for key in all_keys: + # Collect all values for this key + values = [d.get(key) for d in dicts if key in d] + + if not values: + continue + + # Check if all values are tensors + if all(torch.is_tensor(v) for v in values): + # For tensors, use concat_padded_tensors + # Create list of single-item dicts for concat_padded_tensors + tensor_dicts = [{key: v} for v in values] + aggregated = concat_padded_tensors(tensor_dicts, pad_value=pad_value) + result[key] = aggregated[key] + + # Check if all values are numeric (int, float) + elif all(isinstance(v, (int, float)) for v in values): + # Sum numeric values + result[key] = sum(values) + + else: + # For mixed or other types, keep as list + result[key] = values + + return result + + +def truncate_dict_to_batch_size( + data: Dict[str, Any], batch_size: int +) -> Dict[str, Any]: + """Truncate a dictionary containing tensors and numeric values to specified batch size. + + This function handles different value types: + - Tensors: take first batch_size elements along the first dimension + - Numeric values: keep as is (no truncation) + - Other types: keep as is (no truncation) + + Args: + data: Dictionary to truncate + batch_size: Target batch size for truncation + + Returns: + Truncated dictionary + """ + if not data: + return {} + + result = {} + + for key, value in data.items(): + if torch.is_tensor(value) and len(value.shape) > 0: + # For tensors, take first batch_size elements along first dimension + if value.shape[0] > batch_size: + result[key] = value[:batch_size] + else: + result[key] = value + else: + # For numeric values and other types, keep as is + result[key] = value + + return result + + def unpack_sequence( x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, diff --git a/areal/utils/functional.py b/areal/utils/functional.py index 11421b8ee..050ba373f 100644 --- a/areal/utils/functional.py +++ b/areal/utils/functional.py @@ -281,7 +281,11 @@ def ppo_critic_loss_fn( return value_loss, stat -def dynamic_sampling( +def filter_batch(filter_batch_fn, data: Dict[str, Any], group_size: int): + return filter_batch_fn(data, group_size) + + +def filter_batch_fn( data: Dict[str, Any], group_size: int ) -> Tuple[Dict[str, Any], Dict[str, int]]: """Filter samples by group when all rewards in a group are equal. 
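To make the filtering rule above concrete, here is a minimal sketch (not part of the patch) of the group-level mask that `filter_batch_fn` computes: rewards of a group of size G sit adjacent along the batch dimension, and a group is dropped when all of its rewards coincide. The helper name `toy_filter_by_group` is invented for this illustration and does not exist in AReaL.

import torch

def toy_filter_by_group(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
    # Reshape to (num_groups, group_size); samples of one group are adjacent.
    groups = rewards.view(-1, group_size)
    # Keep a group only if its rewards are not all identical.
    keep = ~(groups == groups[:, :1]).all(dim=1)
    # Expand the per-group decision back to a per-sample boolean mask.
    return keep.repeat_interleave(group_size)

rewards = torch.tensor([1.0, 1.0, 0.0, 0.5, 0.5, 0.5])  # two groups of size 3
print(toy_filter_by_group(rewards, group_size=3))
# tensor([ True,  True,  True, False, False, False])
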
diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index 317537338..8f1d30bb5 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -2,6 +2,7 @@ import sys from copy import deepcopy +import debugpy import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader @@ -14,18 +15,24 @@ from areal.platforms import current_platform from areal.utils import seeding, stats_tracker from areal.utils.data import ( + aggregate_dicts, broadcast_tensor_container, cycle_dataloader, tensor_container_to, + truncate_dict_to_batch_size, ) from areal.utils.device import log_gpu_stats from areal.utils.evaluator import Evaluator +from areal.utils.functional import filter_batch from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver from areal.utils.stats_logger import StatsLogger from areal.workflow.rlvr import RLVRWorkflow +debugpy.listen(("0.0.0.0", 5678)) +debugpy.wait_for_client() + def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): from areal.reward.math_parser import process_results @@ -69,9 +76,15 @@ def main(args): ) # Create dataset and dataloaders + train_dataloader_batch_size = ( + config.train_dataset.batch_size // actor.data_parallel_world_size + ) + valid_dataloader_batch_size = ( + config.valid_dataset.batch_size // actor.data_parallel_world_size + ) train_dataloader = StatefulDataLoader( train_dataset, - batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size, + batch_size=train_dataloader_batch_size, shuffle=config.train_dataset.shuffle, num_workers=config.train_dataset.num_workers, collate_fn=lambda x: x, @@ -79,7 +92,7 @@ def main(args): ) valid_dataloader = StatefulDataLoader( valid_dataset, - batch_size=config.valid_dataset.batch_size // actor.data_parallel_world_size, + batch_size=valid_dataloader_batch_size, shuffle=config.valid_dataset.shuffle, num_workers=config.valid_dataset.num_workers, collate_fn=lambda x: x, @@ -170,31 +183,78 @@ def main(args): epoch_step=step, steps_per_epoch=steps_per_epoch, ) - - with stats_tracker.record_timing("rollout"): - batch = None - if actor.is_data_parallel_head(): - if config.async_training: - batch = rollout.prepare_batch( - train_dataloader, - workflow=workflow, - should_accept=lambda sample: True, - ) - else: - batch = rollout.rollout_batch( - next(data_generator), - workflow=workflow, - should_accept=lambda sample: True, + # Initialize batch collection + collected_batches = [] + while True: + with stats_tracker.record_timing("rollout"): + new_batch = None + if actor.is_data_parallel_head(): + if config.async_training: + new_batch = rollout.prepare_batch( + train_dataloader, + workflow=workflow, + should_accept=lambda sample: True, + ) + else: + new_batch = rollout.rollout_batch( + next(data_generator), + workflow=workflow, + should_accept=lambda sample: True, + ) + new_batch = tensor_container_to(new_batch, actor.device) + new_batch = broadcast_tensor_container( + new_batch, + src_rank=actor.current_data_parallel_head(), + group=actor.context_and_model_parallel_group, + ) + + # Create barrier to synchronize all rollout processes. 
+ dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(new_batch) + log_gpu_stats("compute advantages") + + # Collect the batch and process it immediately + if config.actor.dynamic_sampling: + with stats_tracker.record_timing("rollout_refill_dapo_batch_buffers"): + # Filter the current batch by groups + filtered_batch, sampling_stat = filter_batch( + new_batch, config.actor.group_size ) - batch = tensor_container_to(batch, actor.device) - batch = broadcast_tensor_container( - batch, - src_rank=actor.current_data_parallel_head(), - group=actor.context_and_model_parallel_group, - ) - # Create barrier to synchronize all rollout processes. - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() + + # Add filtered batch to collection + collected_batches.append(filtered_batch) + + # Aggregate all collected batches + aggregated_batch = aggregate_dicts(collected_batches) + + # Check if we have collected enough samples + if ( + len(aggregated_batch["rewards"]) + >= config.train_dataset.batch_size + ): + # Log the sampling statistics + stats_logger.commit( + epoch=epoch, + step=step, + global_step=global_step, + data=sampling_stat, + ) + + # Truncate batch to train_batch_size + batch = truncate_dict_to_batch_size( + aggregated_batch, config.train_dataset.batch_size + ) + break + else: + # Continue collecting more samples + batch = aggregated_batch + else: + # For non-dynamic sampling, just use the current batch + batch = new_batch + break if config.actor.recompute_logprob or config.actor.use_decoupled_loss: with stats_tracker.record_timing("recompute_logp"): @@ -207,10 +267,6 @@ def main(args): batch["ref_logp"] = ref.compute_logp(batch) log_gpu_stats("ref logp") - with stats_tracker.record_timing("compute_advantage"): - actor.compute_advantages(batch) - log_gpu_stats("compute advantages") - with ( stats_tracker.record_timing("train_step"), stats_tracker.scope("grpo_actor"), diff --git a/examples/experimental/dapo/gsm8k_dapo.yaml b/examples/experimental/dapo/gsm8k_dapo.yaml index 8b706563b..613474711 100644 --- a/examples/experimental/dapo/gsm8k_dapo.yaml +++ b/examples/experimental/dapo/gsm8k_dapo.yaml @@ -4,7 +4,7 @@ trial_name: trial0 seed: 1 total_train_epochs: 10 tokenizer_path: ${actor.path} -async_training: true +async_training: false cluster: n_nodes: 1 @@ -14,7 +14,7 @@ cluster: type: nfs nfs_record_root: /tmp/areal/name_resolve -allocation_mode: sglang.d4+fsdp.d4 +allocation_mode: sglang:d2+fsdp:d2 rollout: experiment_name: ${experiment_name} From 889ceb791f2c87bdcd3c9887bb972f774423ae3c Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Tue, 14 Oct 2025 00:39:38 -0900 Subject: [PATCH 02/24] . --- areal/api/alloc_mode.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/areal/api/alloc_mode.py b/areal/api/alloc_mode.py index e2cda6f4f..b452111e3 100644 --- a/areal/api/alloc_mode.py +++ b/areal/api/alloc_mode.py @@ -839,14 +839,6 @@ def parse(self, expression: str): AllocationValidationError: When validation rules are violated ValueError: When parsing fails """ - # Check for common syntax errors and provide helpful messages - if "." in expression and ":" not in expression: - # User likely used dot notation instead of colon - raise ValueError( - f"Invalid allocation mode syntax: '{expression}'\n" - f"Please use colon ':' instead of dot '.' 
to separate backend and dimensions.\n" - f"Example: 'sglang:d4+fsdp:d4' instead of 'sglang.d4+fsdp.d4'" - ) try: tree = self.parser.parse(expression) From 3d3bfb631aae1f4c4670d6fa5671e2938e42a05b Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Tue, 14 Oct 2025 03:03:10 -0900 Subject: [PATCH 03/24] . --- areal/utils/data.py | 12 +----- areal/utils/functional.py | 2 +- examples/experimental/dapo/gsm8k_dapo.py | 49 ++++++++++++++---------- 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index 0f2f66879..d5e671775 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -221,7 +221,7 @@ def concat_padded_tensors( return result -def aggregate_dicts( +def aggregate_metric_dicts( dicts: List[Dict[str, Any]], pad_value: float = 0.0 ) -> Dict[str, Any]: """Aggregate multiple dictionaries containing tensors and numeric values. @@ -255,16 +255,8 @@ def aggregate_dicts( if not values: continue - # Check if all values are tensors - if all(torch.is_tensor(v) for v in values): - # For tensors, use concat_padded_tensors - # Create list of single-item dicts for concat_padded_tensors - tensor_dicts = [{key: v} for v in values] - aggregated = concat_padded_tensors(tensor_dicts, pad_value=pad_value) - result[key] = aggregated[key] - # Check if all values are numeric (int, float) - elif all(isinstance(v, (int, float)) for v in values): + if all(isinstance(v, (int, float)) for v in values): # Sum numeric values result[key] = sum(values) diff --git a/areal/utils/functional.py b/areal/utils/functional.py index 050ba373f..152985dca 100644 --- a/areal/utils/functional.py +++ b/areal/utils/functional.py @@ -285,7 +285,7 @@ def filter_batch(filter_batch_fn, data: Dict[str, Any], group_size: int): return filter_batch_fn(data, group_size) -def filter_batch_fn( +def filter_batch_fn_DAPO( data: Dict[str, Any], group_size: int ) -> Tuple[Dict[str, Any], Dict[str, int]]: """Filter samples by group when all rewards in a group are equal. 
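The simplified metric aggregation introduced here only needs to sum numeric counters across refill iterations before they are turned into ratios. Below is a hedged usage sketch of that pattern; `sum_metric_dicts` is a stand-in written for this note so the snippet stays self-contained, while the patch's own helper is `aggregate_metric_dicts` in areal/utils/data.py (which additionally keeps non-numeric values as lists).

from typing import Dict, List

def sum_metric_dicts(dicts: List[Dict[str, float]]) -> Dict[str, float]:
    # Sum each numeric counter across the per-iteration stat dicts.
    totals: Dict[str, float] = {}
    for d in dicts:
        for key, value in d.items():
            totals[key] = totals.get(key, 0) + value
    return totals

stats = sum_metric_dicts([
    {"n_group_kept": 6, "n_group_filtered": 2},
    {"n_group_kept": 4, "n_group_filtered": 4},
])
keep_ratio = stats["n_group_kept"] / (stats["n_group_kept"] + stats["n_group_filtered"])
print(stats, keep_ratio)  # {'n_group_kept': 10, 'n_group_filtered': 6} 0.625
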
diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index 8f1d30bb5..4a2adbb2a 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -2,7 +2,6 @@ import sys from copy import deepcopy -import debugpy import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader @@ -15,24 +14,22 @@ from areal.platforms import current_platform from areal.utils import seeding, stats_tracker from areal.utils.data import ( - aggregate_dicts, + aggregate_metric_dicts, broadcast_tensor_container, + concat_padded_tensors, cycle_dataloader, tensor_container_to, truncate_dict_to_batch_size, ) from areal.utils.device import log_gpu_stats from areal.utils.evaluator import Evaluator -from areal.utils.functional import filter_batch +from areal.utils.functional import filter_batch, filter_batch_fn_DAPO from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver from areal.utils.stats_logger import StatsLogger from areal.workflow.rlvr import RLVRWorkflow -debugpy.listen(("0.0.0.0", 5678)) -debugpy.wait_for_client() - def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): from areal.reward.math_parser import process_results @@ -184,7 +181,7 @@ def main(args): steps_per_epoch=steps_per_epoch, ) # Initialize batch collection - collected_batches = [] + collected_batches, sampling_stats = [], [] while True: with stats_tracker.record_timing("rollout"): new_batch = None @@ -220,37 +217,47 @@ def main(args): if config.actor.dynamic_sampling: with stats_tracker.record_timing("rollout_refill_dapo_batch_buffers"): # Filter the current batch by groups + # at least one group will be kept, don't worry for infinite loop + # max try: train_dataloader_batch_size(batch_size)/1(min kept)=train_dataloader_batch_size filtered_batch, sampling_stat = filter_batch( - new_batch, config.actor.group_size + filter_batch_fn_DAPO, new_batch, config.actor.group_size ) + sampling_stats.append(sampling_stat) # Add filtered batch to collection collected_batches.append(filtered_batch) - # Aggregate all collected batches - aggregated_batch = aggregate_dicts(collected_batches) - + # Aggregate all filter/clean batches + aggregated_batch = concat_padded_tensors(collected_batches) + total_batch_size = len(new_batch["rewards"]) + # just for sanity check + assert ( + total_batch_size + == train_dataloader_batch_size * config.actor.group_size + ) # Check if we have collected enough samples - if ( - len(aggregated_batch["rewards"]) - >= config.train_dataset.batch_size - ): + if len(aggregated_batch["rewards"]) >= total_batch_size: + sampling_stats = aggregate_metric_dicts(sampling_stats) + keep_ratio = float(sampling_stats["n_group_kept"]) / float( + sampling_stats["n_group_filtered"] + + sampling_stats["n_group_kept"] + ) + sampling_stats = { + "keep_ratio": keep_ratio, + "filer_ratio": 1.0 - keep_ratio, + } # Log the sampling statistics stats_logger.commit( epoch=epoch, step=step, global_step=global_step, - data=sampling_stat, + data=sampling_stats, ) - # Truncate batch to train_batch_size batch = truncate_dict_to_batch_size( - aggregated_batch, config.train_dataset.batch_size + data=aggregated_batch, batch_size=total_batch_size ) break - else: - # Continue collecting more samples - batch = aggregated_batch else: # For non-dynamic sampling, just use the current batch batch = new_batch From af5a93005f5a25fce72851b0a3947704e247319c Mon 
Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Tue, 14 Oct 2025 03:04:44 -0900 Subject: [PATCH 04/24] . --- examples/experimental/dapo/gsm8k_dapo.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/experimental/dapo/gsm8k_dapo.yaml b/examples/experimental/dapo/gsm8k_dapo.yaml index 613474711..c849f0f89 100644 --- a/examples/experimental/dapo/gsm8k_dapo.yaml +++ b/examples/experimental/dapo/gsm8k_dapo.yaml @@ -4,7 +4,7 @@ trial_name: trial0 seed: 1 total_train_epochs: 10 tokenizer_path: ${actor.path} -async_training: false +async_training: true cluster: n_nodes: 1 @@ -14,7 +14,7 @@ cluster: type: nfs nfs_record_root: /tmp/areal/name_resolve -allocation_mode: sglang:d2+fsdp:d2 +allocation_mode: sglang:d4+fsdp:d4 rollout: experiment_name: ${experiment_name} From 44532432603c73a86d8398f3f472c79e0c638f6d Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Tue, 14 Oct 2025 20:11:13 +0800 Subject: [PATCH 05/24] Update areal/experimental/megatron_actor.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/experimental/megatron_actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/experimental/megatron_actor.py b/areal/experimental/megatron_actor.py index a143c31c8..e2096f2a8 100644 --- a/areal/experimental/megatron_actor.py +++ b/areal/experimental/megatron_actor.py @@ -179,7 +179,7 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = filter_batch(data, self.group_size) + data, sampling_stat = filter_batch(filter_batch_fn_DAPO, data, self.group_size) attn_mask = data["attention_mask"] loss_mask = data["loss_mask"] From 841d16a9994ed6c125e06f14b544046a49497414 Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Tue, 14 Oct 2025 03:14:58 -0900 Subject: [PATCH 06/24] . --- areal/utils/functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/areal/utils/functional.py b/areal/utils/functional.py index 152985dca..3737c016e 100644 --- a/areal/utils/functional.py +++ b/areal/utils/functional.py @@ -328,9 +328,10 @@ def filter_batch_fn_DAPO( # Expand the group mask to individual samples mask = valid_groups.repeat_interleave(group_size) - # In case all group is filtered out, return the original data (although not gradient in this case) + # In case all group is filtered out, only return only the first group to avoid infinite loop of dynamic sampling (although not gradient in this case) if not mask.any(): - return data, dict(n_group_kept=0, n_group_filtered=num_groups) + mask[:group_size] = True + valid_groups[0] = True n_group_kept = int(valid_groups.sum().item()) n_group_filtered = int(num_groups - n_group_kept) From f93869329bd7b46ab225c6163de6ab3b1e1346ad Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Tue, 14 Oct 2025 03:16:39 -0900 Subject: [PATCH 07/24] . --- areal/utils/data.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index d5e671775..efa24e4d0 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -224,16 +224,10 @@ def concat_padded_tensors( def aggregate_metric_dicts( dicts: List[Dict[str, Any]], pad_value: float = 0.0 ) -> Dict[str, Any]: - """Aggregate multiple dictionaries containing tensors and numeric values. 
- - This function handles different value types: - - Tensors: concatenated and padded to max length - - Numeric values: summed across dictionaries - - Other types: kept as lists + """Aggregate multiple dictionaries containing numeric values. Args: dicts: List of dictionaries to aggregate - pad_value: Value to use for padding tensors Returns: Aggregated dictionary with the same keys From 5b12bf92f23bbdd08872de7a9d69104c72b3de63 Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Tue, 14 Oct 2025 20:27:59 +0800 Subject: [PATCH 08/24] Update areal/utils/data.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/utils/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index efa24e4d0..130c03817 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -222,7 +222,7 @@ def concat_padded_tensors( def aggregate_metric_dicts( - dicts: List[Dict[str, Any]], pad_value: float = 0.0 + dicts: List[Dict[str, Any]] ) -> Dict[str, Any]: """Aggregate multiple dictionaries containing numeric values. From 8c2ddc88e0ed24a978d2bb9c0940155d89da3e35 Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Tue, 14 Oct 2025 20:28:42 +0800 Subject: [PATCH 09/24] Update areal/utils/functional.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/utils/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/utils/functional.py b/areal/utils/functional.py index 3737c016e..4ff06005f 100644 --- a/areal/utils/functional.py +++ b/areal/utils/functional.py @@ -328,7 +328,7 @@ def filter_batch_fn_DAPO( # Expand the group mask to individual samples mask = valid_groups.repeat_interleave(group_size) - # In case all group is filtered out, only return only the first group to avoid infinite loop of dynamic sampling (although not gradient in this case) + # In case all groups are filtered out, keep the first group to prevent an infinite loop in the data collection process (though this group will not contribute to the gradient). if not mask.any(): mask[:group_size] = True valid_groups[0] = True From 44c70725c98bb08f4d07af52f0477c86488d5048 Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Tue, 14 Oct 2025 20:29:08 +0800 Subject: [PATCH 10/24] Update examples/experimental/dapo/gsm8k_dapo.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- examples/experimental/dapo/gsm8k_dapo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index 4a2adbb2a..14677335d 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -217,8 +217,8 @@ def main(args): if config.actor.dynamic_sampling: with stats_tracker.record_timing("rollout_refill_dapo_batch_buffers"): # Filter the current batch by groups - # at least one group will be kept, don't worry for infinite loop - # max try: train_dataloader_batch_size(batch_size)/1(min kept)=train_dataloader_batch_size + # In the worst-case scenario where each new batch contributes only one group (the minimum kept), + # it will take `train_dataloader_batch_size` retries to fill the target batch. 
filtered_batch, sampling_stat = filter_batch( filter_batch_fn_DAPO, new_batch, config.actor.group_size ) From 576e3427ca755206a9ca4e78debadd1fbbd7157d Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Tue, 14 Oct 2025 20:29:30 +0800 Subject: [PATCH 11/24] Update examples/experimental/dapo/gsm8k_dapo.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- examples/experimental/dapo/gsm8k_dapo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index 14677335d..cd1cd719c 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -244,7 +244,7 @@ def main(args): ) sampling_stats = { "keep_ratio": keep_ratio, - "filer_ratio": 1.0 - keep_ratio, + "filter_ratio": 1.0 - keep_ratio, } # Log the sampling statistics stats_logger.commit( From 889a884538279987a05a33608dbd962ba8909a0f Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Thu, 16 Oct 2025 03:47:57 -0900 Subject: [PATCH 12/24] . --- areal/api/workflow_api.py | 2 +- areal/utils/functional.py | 70 +------------ docs/algorithms/dapo.md | 4 + examples/experimental/dapo/gsm8k_dapo.py | 122 ++++++++--------------- 4 files changed, 49 insertions(+), 149 deletions(-) diff --git a/areal/api/workflow_api.py b/areal/api/workflow_api.py index 22d4facde..8d3b3d6c0 100644 --- a/areal/api/workflow_api.py +++ b/areal/api/workflow_api.py @@ -240,7 +240,7 @@ class _TimedResult: class _RolloutTaskInput: data: Dict[str, Any] workflow: RolloutWorkflow - should_accept: Callable | None = None + should_accept: Callable | Dict[str, Any] = None @dataclass diff --git a/areal/utils/functional.py b/areal/utils/functional.py index 4ff06005f..b5202f22b 100644 --- a/areal/utils/functional.py +++ b/areal/utils/functional.py @@ -1,5 +1,4 @@ import functools -import warnings from typing import Any, Dict, Optional, Tuple import numpy as np @@ -281,70 +280,11 @@ def ppo_critic_loss_fn( return value_loss, stat -def filter_batch(filter_batch_fn, data: Dict[str, Any], group_size: int): - return filter_batch_fn(data, group_size) - - -def filter_batch_fn_DAPO( - data: Dict[str, Any], group_size: int -) -> Tuple[Dict[str, Any], Dict[str, int]]: - """Filter samples by group when all rewards in a group are equal. - - Assumes samples of the same group are adjacent in the batch. - - Returns a new dict containing only kept samples (mask applied on batch dim - for all tensor values whose first dimension equals batch size), and a small - stats dict. - """ - rewards = data["rewards"] - if not torch.is_tensor(rewards): - raise TypeError("data['rewards'] must be a torch.Tensor") - batch_size = rewards.shape[0] - - if group_size <= 0: - warnings.warn("group_size <= 0; returning original data") - return data, dict(n_group_kept=0, n_group_filtered=0) - - if batch_size % group_size != 0: - warnings.warn( - "The group size is not divisible by the batch size. 
Return the original data" - ) - return data, dict( - n_group_kept=batch_size // max(group_size, 1), n_group_filtered=0 - ) - - # Calculate number of groups (must be divisible) - num_groups = batch_size // group_size - - # Reshape rewards to (num_groups, group_size) for group-wise operations - rewards_reshaped = rewards.view(num_groups, group_size) - - # Check if all elements in each group are equal to the first element - all_equal = (rewards_reshaped == rewards_reshaped[:, 0:1]).all(dim=1) - - # Create mask for groups to keep (where not all rewards are equal) - valid_groups = ~all_equal - - # Expand the group mask to individual samples - mask = valid_groups.repeat_interleave(group_size) - - # In case all groups are filtered out, keep the first group to prevent an infinite loop in the data collection process (though this group will not contribute to the gradient). - if not mask.any(): - mask[:group_size] = True - valid_groups[0] = True - - n_group_kept = int(valid_groups.sum().item()) - n_group_filtered = int(num_groups - n_group_kept) - - # Apply mask row-wise across tensors that share the same batch dimension - filtered: Dict[str, Any] = {} - for k, v in data.items(): - if torch.is_tensor(v) and v.shape[:1] == (batch_size,): - filtered[k] = v[mask] - else: - # keep untouched (e.g., scalars, metadata); caller should ensure consistency - filtered[k] = v - return filtered, dict(n_group_kept=n_group_kept, n_group_filtered=n_group_filtered) +def filter_batch_fn_DAPO_per_group(data: Dict[str, Any]) -> bool: + rewards_tensor = data["rewards"] + if rewards_tensor.shape[0] == 1: + raise ValueError("DAPO is base on group and requires batch size > 1") + return torch.all(rewards_tensor == rewards_tensor[0]) # code modified from VERL: https://github.com/volcengine/verl/blob/main/verl/workers/reward_manager/dapo.py diff --git a/docs/algorithms/dapo.md b/docs/algorithms/dapo.md index 1c31d954d..7b24f8065 100644 --- a/docs/algorithms/dapo.md +++ b/docs/algorithms/dapo.md @@ -45,6 +45,10 @@ We only list the different parameters from GRPO here: - `actor.eps_clip_higher`: The higher bound of clipping. - `actor.dynamic_sampling`: Define if dynamic sampling should be used. +### Dynamic Sampling Strategy +By default, a group will be filtered out if all tracjectorys in this group have the same reward. You can customize this by `filter_batch_fn_DAPO_per_group` within the `./areal/utils/functional.py` + + ### Overlong Penalty Here we briefly introduce the implementation details of DAPO. 
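The "Dynamic Sampling Strategy" note added to the doc above describes the default rule: a group is filtered out when every trajectory in it received the same reward. As a small illustration of a per-group acceptance predicate for that rule, assuming it is handed one group's rollout dict containing a "rewards" tensor and should return True to keep the group; this is a sketch of the documented behavior, not the function introduced by this patch.

import torch
from typing import Any, Dict

def accept_group_with_reward_variance(data: Dict[str, Any]) -> bool:
    rewards = data["rewards"]
    if rewards.numel() <= 1:
        raise ValueError("group-level filtering requires more than one trajectory per group")
    # Keep the group only if its rewards are not all identical.
    return not torch.all(rewards == rewards[0]).item()

print(accept_group_with_reward_variance({"rewards": torch.tensor([1.0, 0.0, 1.0])}))  # True
print(accept_group_with_reward_variance({"rewards": torch.tensor([1.0, 1.0, 1.0])}))  # False
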
diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index bf36e9c84..f1eed1d94 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -1,3 +1,4 @@ +import dataclasses import os import sys from copy import deepcopy @@ -13,17 +14,14 @@ from areal.platforms import current_platform from areal.utils import seeding, stats_tracker from areal.utils.data import ( - aggregate_metric_dicts, broadcast_tensor_container, - concat_padded_tensors, cycle_dataloader, tensor_container_to, - truncate_dict_to_batch_size, ) from areal.utils.dataloader import create_dataloader from areal.utils.device import log_gpu_stats from areal.utils.evaluator import Evaluator -from areal.utils.functional import filter_batch, filter_batch_fn_DAPO +from areal.utils.functional import filter_batch_fn_DAPO_per_group from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver @@ -162,87 +160,45 @@ def main(args): steps_per_epoch=steps_per_epoch, ) # Initialize batch collection - collected_batches, sampling_stats = [], [] - while True: - with stats_tracker.record_timing("rollout"): - new_batch = None - if actor.is_data_parallel_head(): - if config.async_training: - new_batch = rollout.prepare_batch( - train_dataloader, - workflow=workflow, - should_accept=lambda sample: True, - ) - else: - new_batch = rollout.rollout_batch( - next(data_generator), - workflow=workflow, - should_accept=lambda sample: True, - ) - new_batch = tensor_container_to(new_batch, actor.device) - new_batch = broadcast_tensor_container( - new_batch, - src_rank=actor.current_data_parallel_head(), - group=actor.context_and_model_parallel_group, - ) - - # Create barrier to synchronize all rollout processes. - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - with stats_tracker.record_timing("compute_advantage"): - actor.compute_advantages(new_batch) - log_gpu_stats("compute advantages") - - # Collect the batch and process it immediately - if config.actor.dynamic_sampling: - with stats_tracker.record_timing("rollout_refill_dapo_batch_buffers"): - # Filter the current batch by groups - # In the worst-case scenario where each new batch contributes only one group (the minimum kept), - # it will take `train_dataloader_batch_size` retries to fill the target batch. 
- filtered_batch, sampling_stat = filter_batch( - filter_batch_fn_DAPO, new_batch, config.actor.group_size + + with stats_tracker.record_timing("rollout"): + if actor.is_data_parallel_head(): + if config.async_training: + batch = rollout.prepare_batch( + train_dataloader, + workflow=workflow, + should_accept=( + filter_batch_fn_DAPO_per_group + if config.actor.dynamic_sampling + else None + ), ) - sampling_stats.append(sampling_stat) - - # Add filtered batch to collection - collected_batches.append(filtered_batch) - - # Aggregate all filter/clean batches - aggregated_batch = concat_padded_tensors(collected_batches) - total_batch_size = len(new_batch["rewards"]) - # just for sanity check - assert ( - total_batch_size - == train_dataloader_batch_size * config.actor.group_size + else: + batch = rollout.rollout_batch( + next(data_generator), + workflow=workflow, + should_accept=( + filter_batch_fn_DAPO_per_group + if config.actor.dynamic_sampling + else None + ), ) - # Check if we have collected enough samples - if len(aggregated_batch["rewards"]) >= total_batch_size: - sampling_stats = aggregate_metric_dicts(sampling_stats) - keep_ratio = float(sampling_stats["n_group_kept"]) / float( - sampling_stats["n_group_filtered"] - + sampling_stats["n_group_kept"] - ) - sampling_stats = { - "keep_ratio": keep_ratio, - "filter_ratio": 1.0 - keep_ratio, - } - # Log the sampling statistics - stats_logger.commit( - epoch=epoch, - step=step, - global_step=global_step, - data=sampling_stats, - ) - # Truncate batch to train_batch_size - batch = truncate_dict_to_batch_size( - data=aggregated_batch, batch_size=total_batch_size - ) - break - else: - # For non-dynamic sampling, just use the current batch - batch = new_batch - break + new_batch = tensor_container_to(new_batch, actor.device) + batch = broadcast_tensor_container( + new_batch, + src_rank=actor.current_data_parallel_head(), + group=actor.context_and_model_parallel_group, + ) + + # Create barrier to synchronize all rollout processes. + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + # record nums of group acceptance + stats_logger.commit(dataclasses.asdict(rollout.workflow_executor.rollout_stat)) + + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(new_batch) + log_gpu_stats("compute advantages") if config.actor.recompute_logprob or config.actor.use_decoupled_loss: with stats_tracker.record_timing("recompute_logp"): From b91dbb4eaae8286a42486c529318236ba453c61f Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Sat, 18 Oct 2025 07:22:09 -0900 Subject: [PATCH 13/24] . --- areal/api/cli_args.py | 6 +- areal/utils/data.py | 40 ---------- areal/utils/stats_logger.py | 65 +++++++++++++++- docs/algorithms/dapo.md | 10 ++- docs/cli_reference.md | 98 ++++++++++++------------ examples/experimental/dapo/gsm8k_dapo.py | 61 +++++++-------- 6 files changed, 151 insertions(+), 129 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 6528c320c..081e17920 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -424,12 +424,10 @@ class PPOActorConfig(TrainEngineConfig): }, ) # Advanced Options - dynamic_sampling: bool = field( + dynamic_sampling: str = field( default=False, metadata={ - "help": "Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. " - "Note that enabling this option will lead to variable batch sizes. 
If you want to use a constant batch size with dynamic filtering, " - "you should use the `should_accept` parameter in `rollout_batch` and `prepare_batch`." + "help": "Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details" }, ) diff --git a/areal/utils/data.py b/areal/utils/data.py index 130c03817..cd4a8eb70 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -221,46 +221,6 @@ def concat_padded_tensors( return result -def aggregate_metric_dicts( - dicts: List[Dict[str, Any]] -) -> Dict[str, Any]: - """Aggregate multiple dictionaries containing numeric values. - - Args: - dicts: List of dictionaries to aggregate - - Returns: - Aggregated dictionary with the same keys - """ - if not dicts: - return {} - - # Get all unique keys from all dictionaries - all_keys = set() - for d in dicts: - all_keys.update(d.keys()) - - result = {} - - for key in all_keys: - # Collect all values for this key - values = [d.get(key) for d in dicts if key in d] - - if not values: - continue - - # Check if all values are numeric (int, float) - if all(isinstance(v, (int, float)) for v in values): - # Sum numeric values - result[key] = sum(values) - - else: - # For mixed or other types, keep as list - result[key] = values - - return result - - def truncate_dict_to_batch_size( data: Dict[str, Any], batch_size: int ) -> Dict[str, Any]: diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index bde19e762..6a12c24a5 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -1,7 +1,7 @@ import getpass import os import time -from typing import Dict, List +from typing import Any, Dict, List import swanlab import torch.distributed as dist @@ -147,3 +147,66 @@ def get_log_path(config: StatsLoggerConfig): path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}" os.makedirs(path, exist_ok=True) return path + + +def aggregate_metric_dicts(dicts: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate multiple dictionaries containing numeric values. 
+ + Args: + dicts: List of dictionaries to aggregate + + Returns: + Aggregated dictionary with the same keys + """ + if not dicts: + return {} + + # Get all unique keys from all dictionaries + all_keys = set() + for d in dicts: + all_keys.update(d.keys()) + + result = {} + + for key in all_keys: + # Collect all values for this key + values = [d.get(key) for d in dicts if key in d] + + if not values: + continue + + # Check if all values are numeric (int, float) + if all(isinstance(v, (int, float)) for v in values): + # Sum numeric values + result[key] = sum(values) + + else: + # For mixed or other types, keep as list + result[key] = values + + return result + + +def log_sampling_stats( + sampling_stats_list: list, + epoch: int, + step: int, + global_step: int, + stats_logger: StatsLogger, +): + """Helper function to log sampling statistics for both static and dynamic sampling.""" + sampling_stats = aggregate_metric_dicts(sampling_stats_list) + keep_ratio = float(sampling_stats["n_group_kept"]) / float( + sampling_stats["n_group_filtered"] + sampling_stats["n_group_kept"] + ) + sampling_stats = { + "keep_ratio": keep_ratio, + "filter_ratio": 1.0 - keep_ratio, + } + # Log the sampling statistics + stats_logger.commit( + epoch=epoch, + step=step, + global_step=global_step, + data=sampling_stats, + ) diff --git a/docs/algorithms/dapo.md b/docs/algorithms/dapo.md index 1c31d954d..617c1dbbc 100644 --- a/docs/algorithms/dapo.md +++ b/docs/algorithms/dapo.md @@ -43,7 +43,7 @@ We only list the different parameters from GRPO here: - `actor.overlong_penalty_factor`: The factor of overlong penalty. - `actor.eps_clip`: The lower bound of clipping, default is `0.2`. - `actor.eps_clip_higher`: The higher bound of clipping. -- `actor.dynamic_sampling`: Define if dynamic sampling should be used. +- `actor.dynamic_sampling`: Define the dynamic sampling strategy, selected from `none`, `static` and `dynamic`. ### Overlong Penalty @@ -51,6 +51,14 @@ Here we briefly introduce the implementation details of DAPO. ![alt text](../figures/dapo_overlong_penalty.jpg) +### Dynamic Sampling Strategy + +- `none`: Turn off dynamic sampling. +- `static`: Only one rollout turn and apply the filter function on it, this may result in variable batch size. +- `dynamic`: Enable Multi-turn rollout to keep the *constant batch size*. + + + ## Example Usage > The algorithm is experimental and may not be stable. diff --git a/docs/cli_reference.md b/docs/cli_reference.md index adb2119fa..85c9bd019 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -308,55 +308,55 @@ Configuration for model optimization during training. Configuration for PPO actor model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. 
**Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `True` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `backend` | string | `""` | Training backend (refer to documentation) | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `group_size` | integer | `1` | Number of sequences in each group | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `temperature` | float | `1.0` | Temperature during generation. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. 
| -| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling` | boolean | `False` | Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, you should use the `should_accept` parameter in `rollout_batch` and `prepare_batch`. | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| Parameter | Type | Default | Description | +| ------------------------- | ---------------------------------------------- | --------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `True` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `backend` | string | `""` | Training backend (refer to documentation) | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `group_size` | integer | `1` | Number of sequences in each group | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `temperature` | float | `1.0` | Temperature during generation. 
| +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | +| `dynamic_sampling` | string | `False` | Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | (section-ppo-critic)= diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index cd1cd719c..ee53e3ca7 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -14,10 +14,10 @@ from areal.platforms import current_platform from areal.utils import seeding, stats_tracker from areal.utils.data import ( - aggregate_metric_dicts, broadcast_tensor_container, concat_padded_tensors, cycle_dataloader, + get_batch_size, tensor_container_to, truncate_dict_to_batch_size, ) @@ -27,7 +27,7 @@ from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver -from areal.utils.stats_logger import StatsLogger +from areal.utils.stats_logger import StatsLogger, log_sampling_stats from areal.workflow.rlvr import RLVRWorkflow @@ -41,6 +41,7 @@ def main(args): config, _ = load_expr_config(args, GRPOConfig) config: GRPOConfig + assert config.actor.dynamic_sampling in ["none", "static", "dynamic"] rank = int(os.getenv("RANK")) tokenizer = load_hf_tokenizer(config.tokenizer_path) @@ -171,6 +172,7 @@ def main(args): max_steps = total_epochs * steps_per_epoch data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): epoch = global_step // steps_per_epoch step = global_step % steps_per_epoch @@ -214,48 +216,39 @@ def main(args): log_gpu_stats("compute advantages") # Collect the batch and process it immediately - if config.actor.dynamic_sampling: - with 
stats_tracker.record_timing("rollout_refill_dapo_batch_buffers"): - # Filter the current batch by groups - # In the worst-case scenario where each new batch contributes only one group (the minimum kept), - # it will take `train_dataloader_batch_size` retries to fill the target batch. - filtered_batch, sampling_stat = filter_batch( - filter_batch_fn_DAPO, new_batch, config.actor.group_size + if config.actor.dynamic_sampling in ["static", "dynamic"]: + # Filter the current batch by groups + filtered_batch, sampling_stat = filter_batch( + filter_batch_fn_DAPO, new_batch, config.actor.group_size + ) + sampling_stats.append(sampling_stat) + + if config.actor.dynamic_sampling == "static": + # Statistic sampling: No need to refill for static sampling, result in smaller(variant) batch size + batch = filtered_batch + # Log sampling statistics for static sampling + log_sampling_stats( + sampling_stats, epoch, step, global_step, stats_logger ) - sampling_stats.append(sampling_stat) - + break + else: + # Dynamic sampling: keep collecting batches until we reach the target batch size # Add filtered batch to collection collected_batches.append(filtered_batch) # Aggregate all filter/clean batches aggregated_batch = concat_padded_tensors(collected_batches) - total_batch_size = len(new_batch["rewards"]) - # just for sanity check - assert ( - total_batch_size - == train_dataloader_batch_size * config.actor.group_size - ) + expected_batch_size = get_batch_size(new_batch) + aggregated_batch_size = get_batch_size(aggregated_batch) # Check if we have collected enough samples - if len(aggregated_batch["rewards"]) >= total_batch_size: - sampling_stats = aggregate_metric_dicts(sampling_stats) - keep_ratio = float(sampling_stats["n_group_kept"]) / float( - sampling_stats["n_group_filtered"] - + sampling_stats["n_group_kept"] - ) - sampling_stats = { - "keep_ratio": keep_ratio, - "filter_ratio": 1.0 - keep_ratio, - } - # Log the sampling statistics - stats_logger.commit( - epoch=epoch, - step=step, - global_step=global_step, - data=sampling_stats, + if aggregated_batch_size >= expected_batch_size: + # Log sampling statistics for dynamic sampling + log_sampling_stats( + sampling_stats, epoch, step, global_step, stats_logger ) # Truncate batch to train_batch_size batch = truncate_dict_to_batch_size( - data=aggregated_batch, batch_size=total_batch_size + data=aggregated_batch, batch_size=expected_batch_size ) break else: From 56944c3455c04af489de4bfcc6acf40bfdbc11e6 Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Sat, 18 Oct 2025 07:40:44 -0900 Subject: [PATCH 14/24] . 
--- areal/experimental/megatron_actor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/areal/experimental/megatron_actor.py b/areal/experimental/megatron_actor.py index e2096f2a8..baec8de91 100644 --- a/areal/experimental/megatron_actor.py +++ b/areal/experimental/megatron_actor.py @@ -13,7 +13,6 @@ from areal.utils import stats_tracker from areal.utils.data import Normalization, split_padded_tensor_dict_into_mb_list from areal.utils.functional import ( - filter_batch, gather_logprobs, gather_logprobs_entropy, ppo_actor_loss_fn, @@ -178,9 +177,6 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: data["logprobs"] = old_logp def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: - if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = filter_batch(filter_batch_fn_DAPO, data, self.group_size) - attn_mask = data["attention_mask"] loss_mask = data["loss_mask"] reward_score = data["rewards"] From 249c4289c145e88337a5e6367a2163320a030940 Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Sun, 19 Oct 2025 01:03:22 +0800 Subject: [PATCH 15/24] Update docs/cli_reference.md Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/cli_reference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/cli_reference.md b/docs/cli_reference.md index de8e5f4c4..1749ddcbf 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -356,7 +356,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | | `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | | `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling` | string | `False` | Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details | +| `dynamic_sampling` | string | `"none"` | Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details | | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | From 60e4c82d92a00f5f9ab61c868c80d480642dc29c Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Sat, 18 Oct 2025 08:04:56 -0900 Subject: [PATCH 16/24] . --- areal/api/cli_args.py | 2 +- docs/cli_reference.md | 2 +- examples/experimental/dapo/gsm8k_dapo.py | 2 -- examples/experimental/dapo/gsm8k_dapo.yaml | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index c901543b9..428211158 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -487,7 +487,7 @@ class PPOActorConfig(TrainEngineConfig): ) # Advanced Options dynamic_sampling: str = field( - default=False, + default="none", metadata={ "help": "Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. 
See the doc for more details" }, diff --git a/docs/cli_reference.md b/docs/cli_reference.md index de8e5f4c4..1749ddcbf 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -356,7 +356,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | | `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | | `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling` | string | `False` | Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details | +| `dynamic_sampling` | string | `"none"` | Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details | | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index 034b25aa9..dd7536f71 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -63,8 +63,6 @@ def main(args): ) # Create dataset and dataloaders - train_dataset.batch_size // actor.data_parallel_world_size - valid_dataset.batch_size // actor.data_parallel_world_size train_dataloader = create_dataloader( train_dataset, rank=actor.data_parallel_rank, diff --git a/examples/experimental/dapo/gsm8k_dapo.yaml b/examples/experimental/dapo/gsm8k_dapo.yaml index c849f0f89..102064617 100644 --- a/examples/experimental/dapo/gsm8k_dapo.yaml +++ b/examples/experimental/dapo/gsm8k_dapo.yaml @@ -67,7 +67,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: true + dynamic_sampling: 'dynamic' reward_norm: mean_level: group std_level: group From 9aa2ec4c8df85772668c74e8ecae84893bd46834 Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Sun, 19 Oct 2025 15:26:36 +0800 Subject: [PATCH 17/24] Update areal/api/cli_args.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/api/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 428211158..4cfded6f0 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -489,7 +489,7 @@ class PPOActorConfig(TrainEngineConfig): dynamic_sampling: str = field( default="none", metadata={ - "help": "Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details" + "help": "Dynamic sampling strategy (within DAPO). Select from `none`, `dynamic` and `static`. 
See the doc for more details" }, ) From bdf0ae57a1149cdf7491cb608d18880174794f0d Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Sun, 19 Oct 2025 15:26:48 +0800 Subject: [PATCH 18/24] Update docs/cli_reference.md Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/cli_reference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 1749ddcbf..beac11e33 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -356,7 +356,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | | `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | | `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling` | string | `"none"` | Dynamic sampling strategy ​(within DAPO). Select from none,dynamic and static. See the doc for more details | +| `dynamic_sampling` | string | `"none"` | Dynamic sampling strategy (within DAPO). Select from `none`, `dynamic` and `static`. See the doc for more details | | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | From 2d0ffac98905aeaca6c84ab2f94dc0fc46f8c7fb Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Thu, 23 Oct 2025 07:32:00 +0000 Subject: [PATCH 19/24] . 
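Reviewer note on the rewrite below: under the `dynamic` strategy the script now loops — roll out, filter by group, accumulate — until the filtered pool reaches the target batch size, then truncates so the training batch size stays constant. The following condensed sketch captures that control flow; `rollout_once()` and `target_batch_size` are placeholders, and the imported helpers are used the same way as in the diff, with their exact signatures assumed rather than verified.

```python
# Condensed sketch of the "dynamic" refill loop introduced in this commit.
# rollout_once() stands in for rollout_batch()/prepare_batch().
from areal.utils.data import (
    concat_padded_tensors,
    get_batch_size,
    truncate_dict_to_batch_size,
)
from areal.utils.functional import filter_batch, filter_batch_fn_DAPO


def collect_dynamic_batch(rollout_once, target_batch_size, group_size):
    collected, sampling_stats = [], []
    while True:
        new_batch = rollout_once()
        # Drop degenerate groups, keep per-round filtering statistics.
        filtered, stat = filter_batch(filter_batch_fn_DAPO, new_batch, group_size)
        sampling_stats.append(stat)
        collected.append(filtered)
        aggregated = concat_padded_tensors(collected)
        if get_batch_size(aggregated) >= target_batch_size:
            # Enough groups survived filtering: trim back to a constant size.
            batch = truncate_dict_to_batch_size(
                data=aggregated, batch_size=target_batch_size
            )
            return batch, sampling_stats
```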
--- examples/experimental/dapo/gsm8k_dapo.py | 140 +++++++++++------------ 1 file changed, 68 insertions(+), 72 deletions(-) diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index c7aae877f..b711d96ff 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -22,7 +22,6 @@ from areal.utils.dataloader import create_dataloader from areal.utils.device import log_gpu_stats from areal.utils.evaluator import Evaluator -from areal.utils.functional import filter_batch_fn_DAPO_per_group from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver @@ -163,84 +162,81 @@ def main(args): # Initialize batch collection with stats_tracker.record_timing("rollout"): - if actor.is_data_parallel_head(): - if config.async_training: - batch = rollout.prepare_batch( - train_dataloader, - workflow=workflow, - should_accept=( - filter_batch_fn_DAPO_per_group - if config.actor.dynamic_sampling - else None - ), - ) - else: - batch = rollout.rollout_batch( - next(data_generator), - workflow=workflow, - should_accept=( - filter_batch_fn_DAPO_per_group - if config.actor.dynamic_sampling - else None - ), - ) - new_batch = tensor_container_to(new_batch, actor.device) - batch = broadcast_tensor_container( - new_batch, - src_rank=actor.current_data_parallel_head(), - group=actor.context_and_model_parallel_group, + collected_batches = [] + while True: + if actor.is_data_parallel_head(): + if config.async_training: + new_batch = rollout.prepare_batch( + train_dataloader, + workflow=workflow, + # TODO: refactor API in future + should_accept=None, + ) + else: + new_batch = rollout.rollout_batch( + next(data_generator), + workflow=workflow, + should_accept=None, + ) + new_batch = tensor_container_to(new_batch, actor.device) + new_batch = broadcast_tensor_container( + new_batch, + src_rank=actor.current_data_parallel_head(), + group=actor.context_and_model_parallel_group, + ) + + # Create barrier to synchronize all rollout processes. + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + # record nums of group acceptance + stats_logger.commit( + dataclasses.asdict(rollout.workflow_executor.rollout_stat) ) - # Create barrier to synchronize all rollout processes. 
- dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - # record nums of group acceptance - stats_logger.commit(dataclasses.asdict(rollout.workflow_executor.rollout_stat)) - - with stats_tracker.record_timing("compute_advantage"): - actor.compute_advantages(new_batch) - log_gpu_stats("compute advantages") - - # Collect the batch and process it immediately - if config.actor.dynamic_sampling in ["static", "dynamic"]: - # Filter the current batch by groups - filtered_batch, sampling_stat = filter_batch( - filter_batch_fn_DAPO, new_batch, config.actor.group_size - ) - sampling_stats.append(sampling_stat) - - if config.actor.dynamic_sampling == "static": - # Statistic sampling: No need to refill for static sampling, result in smaller(variant) batch size - batch = filtered_batch - # Log sampling statistics for static sampling - log_sampling_stats( - sampling_stats, epoch, step, global_step, stats_logger + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(new_batch) + log_gpu_stats("compute advantages") + + # Collect the batch and process it immediately + if config.actor.dynamic_sampling in ["static", "dynamic"]: + # Filter the current batch by groups + filtered_batch, sampling_stat = filter_batch( + filter_batch_fn_DAPO, new_batch, config.actor.group_size ) - break - else: - # Dynamic sampling: keep collecting batches until we reach the target batch size - # Add filtered batch to collection - collected_batches.append(filtered_batch) - - # Aggregate all filter/clean batches - aggregated_batch = concat_padded_tensors(collected_batches) - expected_batch_size = get_batch_size(new_batch) - aggregated_batch_size = get_batch_size(aggregated_batch) - # Check if we have collected enough samples - if aggregated_batch_size >= expected_batch_size: - # Log sampling statistics for dynamic sampling + sampling_stats.append(sampling_stat) + + if config.actor.dynamic_sampling == "static": + # Statistic sampling: No need to refill for static sampling, result in smaller(variant) batch size + batch = filtered_batch + # Log sampling statistics for static sampling log_sampling_stats( sampling_stats, epoch, step, global_step, stats_logger ) - # Truncate batch to train_batch_size - batch = truncate_dict_to_batch_size( - data=aggregated_batch, batch_size=expected_batch_size - ) break - else: - # For non-dynamic sampling, just use the current batch - batch = new_batch - break + elif config.actor.dynamic_sampling == "dynamic": + # Dynamic sampling: keep collecting batches until we reach the target batch size + # Add filtered batch to collection + collected_batches.append(filtered_batch) + + # Aggregate all filter/clean batches + aggregated_batch = concat_padded_tensors(collected_batches) + expected_batch_size = get_batch_size(new_batch) + aggregated_batch_size = get_batch_size(aggregated_batch) + # Check if we have collected enough samples + if aggregated_batch_size >= expected_batch_size: + # Log sampling statistics for dynamic sampling + log_sampling_stats( + sampling_stats, epoch, step, global_step, stats_logger + ) + # Truncate batch to train_batch_size + batch = truncate_dict_to_batch_size( + data=aggregated_batch, batch_size=expected_batch_size + ) + break + else: + # For non-dynamic sampling, just use the current batch + batch = new_batch + break if config.actor.recompute_logprob or config.actor.use_decoupled_loss: with stats_tracker.record_timing("recompute_logp"): From b2d70429582744105f99d5a8b6c662b5ba346b0e Mon Sep 17 00:00:00 2001 From: Ziyi 
<1034337098@qq.com> Date: Thu, 23 Oct 2025 22:22:32 -0900 Subject: [PATCH 20/24] . --- areal/api/cli_args.py | 68 ++++++----- docs/algorithms/dapo.md | 2 +- docs/cli_reference.md | 106 +++++++++--------- examples/experimental/dapo/gsm8k_dapo.py | 85 ++++++-------- examples/experimental/dapo/gsm8k_dapo.yaml | 6 +- .../experimental/dr.grpo/gsm8k_drgrpo.yaml | 2 +- .../experimental/lite_ppo/gsm8k_liteppo.yaml | 2 +- examples/lora/gsm8k_grpo_lora.yaml | 2 +- examples/math/boba_grpo_vllm.yaml | 2 +- examples/math/deprecated_gsm8k_grpo.yaml | 2 +- examples/math/gsm8k_grpo.yaml | 2 +- examples/math/gsm8k_ppo.yaml | 2 +- examples/math/gsm8k_reinforce.yaml | 2 +- examples/math/gsm8k_reinforce_baseline.yaml | 2 +- examples/math/gsm8k_rloo.yaml | 2 +- examples/multi-turn-math/config.yaml | 2 +- examples/tir/tir_math_config.yaml | 2 +- examples/vlm/clevr_count_70k_grpo.yaml | 2 +- recipe/AEnt/actor.py | 9 +- recipe/AEnt/configs/gsm8k_aent_grpo.yaml | 2 +- 20 files changed, 144 insertions(+), 160 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 4cfded6f0..12407bb10 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,14 +3,9 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Dict, List import uvloop import yaml - -from areal.utils.pkg_version import is_version_less - -uvloop.install() from hydra import compose as hydra_compose from hydra import initialize as hydra_init from hydra.core.global_hydra import GlobalHydra @@ -18,6 +13,9 @@ from areal.platforms import current_platform from areal.utils import name_resolve, pkg_version +from areal.utils.pkg_version import is_version_less + +uvloop.install() @dataclass @@ -129,11 +127,11 @@ class GenerationHyperparameters: default=1.0, metadata={"help": "Sampling temperature. Higher values increase diversity."}, ) - stop_token_ids: List[int] = field( + stop_token_ids: list[int] = field( default_factory=list, metadata={"help": "Stop generation when encountering these token IDs."}, ) - stop: List[str] | None = field( + stop: list[str] | None = field( default=None, metadata={ "help": "One or multiple stop words. Generation will stop if one of these words is sampled." @@ -232,7 +230,7 @@ class OptimizerConfig: class FSDPWrapPolicy: """Policy configuration for FSDP model layer wrapping. None defaults to wrapping transformer decoder layers defined by transformers.""" - transformer_layer_cls_to_wrap: List[str] | None = field( + transformer_layer_cls_to_wrap: list[str] | None = field( default=None, metadata={"help": "A list of transformer layer names for FSDP to wrap."}, ) @@ -310,7 +308,7 @@ class MegatronEngineConfig: recompute_method: str | None = "uniform" recompute_num_layers: int | None = 1 distribute_saved_activations: bool | None = None - recompute_modules: List[str] | None = None + recompute_modules: list[str] | None = None @dataclass @@ -378,7 +376,7 @@ class TrainEngineConfig: ) lora_rank: int = field(default=32, metadata={"help": "lora rank"}) lora_alpha: int = field(default=16, metadata={"help": "lora alpha"}) - target_modules: List[str] = field( + target_modules: list[str] = field( default_factory=list, metadata={"help": "lora target_modules."}, ) @@ -486,10 +484,10 @@ class PPOActorConfig(TrainEngineConfig): }, ) # Advanced Options - dynamic_sampling: str = field( + dynamic_sampling_strategy: str = field( default="none", metadata={ - "help": "Dynamic sampling strategy (within DAPO). Select from `none`, `dynamic` and `static`. 
See the doc for more details" + "help": "Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. See the doc for more details" }, ) @@ -498,7 +496,7 @@ class PPOActorConfig(TrainEngineConfig): default=False, metadata={"help": "Log statistics for agent trajectories"}, ) - log_agent_stats_keys: List[str] = field( + log_agent_stats_keys: list[str] = field( default_factory=lambda: [], metadata={"help": "Keys for logging agent trajectory statistics"}, ) @@ -572,7 +570,7 @@ def build_args( port, dist_init_addr: str | None = None, ): - args: Dict = conf_as_dict(vllm_config) + args: dict = conf_as_dict(vllm_config) args = dict( host=host, port=port, @@ -606,11 +604,11 @@ def build_cmd( if v is None or v is False or v == "": continue if v is True: - flags.append(f"--{k.replace('_','-')}") + flags.append(f"--{k.replace('_', '-')}") elif isinstance(v, list): - flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}") + flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") else: - flags.append(f"--{k.replace('_','-')} {v}") + flags.append(f"--{k.replace('_', '-')} {v}") return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}" @@ -636,7 +634,7 @@ class SGLangConfig: enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: int | None = None - cuda_graph_bs: List[int] | None = None + cuda_graph_bs: list[int] | None = None torchao_config: str = "" enable_nan_detection: bool = False enable_p2p_check: bool = False @@ -665,8 +663,8 @@ class SGLangConfig: # lora enable_lora: bool | None = None max_lora_rank: int | None = None - lora_target_modules: List[str] | None = None - lora_paths: List[str] | None = None + lora_target_modules: list[str] | None = None + lora_paths: list[str] | None = None max_loaded_loras: int = 1 max_loras_per_batch: int = 1 lora_backend: str = "triton" @@ -717,11 +715,11 @@ def build_cmd( if v is None or v is False or v == "": continue if v is True: - flags.append(f"--{k.replace('_','-')}") + flags.append(f"--{k.replace('_', '-')}") elif isinstance(v, list): - flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}") + flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") else: - flags.append(f"--{k.replace('_','-')} {v}") + flags.append(f"--{k.replace('_', '-')} {v}") return f"python3 -m sglang.launch_server {' '.join(flags)}" @staticmethod @@ -736,11 +734,11 @@ def build_args( node_rank: int = 0, ): # Map "all-linear" to "all" - args: Dict = conf_as_dict(sglang_config) + args: dict = conf_as_dict(sglang_config) if sglang_config.enable_multithread_load or sglang_config.enable_fast_load: - assert pkg_version.is_version_equal( - "sglang", "0.5.2" - ), f"Customized model loading requires exact SGLang version 0.5.2" + assert pkg_version.is_version_equal("sglang", "0.5.2"), ( + "Customized model loading requires exact SGLang version 0.5.2" + ) model_loader_extra_config = dict( enable_multithread_load=sglang_config.enable_multithread_load, enable_fast_load=sglang_config.enable_fast_load, @@ -913,8 +911,8 @@ class WandBConfig: job_type: str | None = None group: str | None = None notes: str | None = None - tags: List[str] | None = None - config: Dict | None = None + tags: list[str] | None = None + config: dict | None = None id_suffix: str | None = "train" @@ -924,7 +922,7 @@ class SwanlabConfig: project: str | None = None name: str | None = None - config: Dict | None = None + config: dict | None = None logdir: str | None = None mode: str | None = "disabled" api_key: str | None = 
os.getenv("SWANLAB_API_KEY", None) @@ -1021,7 +1019,7 @@ class SchedulerConfig: endpoint: str = field(default="http://localhost:8081") deploy_mode: str = field(default="separation") functioncall_service_domain: str = field(default="http://localhost:8080") - reward_functioncall_config: Dict = field(default_factory=dict) + reward_functioncall_config: dict = field(default_factory=dict) reward_model_path: str = field(default="") reward_model_service_url: str = field(default="http://localhost:30000/classify") @@ -1074,7 +1072,7 @@ class SlurmLauncherConfig: default="--mpi=pmi2 -K --chdir $PWD", metadata={"help": "Additional arguments to pass to the srun command."}, ) - additional_bash_cmds: List[str] | None = field( + additional_bash_cmds: list[str] | None = field( default=None, metadata={ "help": "Additional bash commands to setup the container before running " @@ -1242,7 +1240,7 @@ class PPOConfig(GRPOConfig): critic: PPOCriticConfig = field(default_factory=PPOCriticConfig) -def parse_cli_args(argv: List[str]): +def parse_cli_args(argv: list[str]): parser = argparse.ArgumentParser() parser.add_argument( "--config", help="Path to the main configuration file", required=True @@ -1275,7 +1273,7 @@ def to_structured_cfg(cfg, config_cls): return cfg -def load_expr_config(argv: List[str], config_cls): +def load_expr_config(argv: list[str], config_cls): cfg, config_file = parse_cli_args(argv) cfg = to_structured_cfg(cfg, config_cls=config_cls) cfg = OmegaConf.to_object(cfg) @@ -1303,7 +1301,7 @@ def save_config(cfg, log_dir): os.makedirs(log_dir, exist_ok=True) config_save_path = os.path.join(log_dir, "config.yaml") with open(config_save_path, "w") as f: - config_dict: Dict = asdict(cfg) + config_dict: dict = asdict(cfg) yaml.dump( config_dict, f, diff --git a/docs/algorithms/dapo.md b/docs/algorithms/dapo.md index bb4df74e8..24a29ba07 100644 --- a/docs/algorithms/dapo.md +++ b/docs/algorithms/dapo.md @@ -43,7 +43,7 @@ We only list the different parameters from GRPO here: - `actor.overlong_penalty_factor`: The factor of overlong penalty. - `actor.eps_clip`: The lower bound of clipping, default is `0.2`. - `actor.eps_clip_higher`: The higher bound of clipping. -- `actor.dynamic_sampling`: Define the dynamic sampling strategy, selected from `none`, `static` and `dynamic`. +- `actor.dynamic_sampling_strategy`: Define the dynamic sampling strategy, selected from `none`, `static` and `dynamic`. ### Dynamic Sampling Strategy By default, a group will be filtered out if all tracjectorys in this group have the same reward. You can customize this by `filter_batch_fn_DAPO_per_group` within the `./areal/utils/functional.py` diff --git a/docs/cli_reference.md b/docs/cli_reference.md index beac11e33..cd3f6d1ac 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -310,56 +310,56 @@ Configuration for model optimization during training. Configuration for PPO actor model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| ------------------------- | ------------------------------------------------- | --------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. 
**Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `backend` | string | `""` | Training backend (refer to documentation) | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `group_size` | integer | `1` | Number of sequences in each group | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `temperature` | float | `1.0` | Temperature during generation. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. 
| -| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling` | string | `"none"` | Dynamic sampling strategy (within DAPO). Select from `none`, `dynamic` and `static`. See the doc for more details | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| Parameter | Type | Default | Description | +| --------------------------- | ------------------------------------------------- | --------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `backend` | string | `""` | Training backend (refer to documentation) | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `group_size` | integer | `1` | Number of sequences in each group | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `temperature` | float | `1.0` | Temperature during generation. 
| +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | +| `dynamic_sampling_strategy` | string | `"none"` | Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. See the doc for more details | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | (section-ppo-critic)= @@ -712,7 +712,7 @@ Configuration for SwanLab experiment tracking and monitoring. | --------- | -------------- | ------------ | ----------- | | `project` | string \| None | `None` | - | | `name` | string \| None | `None` | - | -| `config` | `Dict` \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | | `logdir` | string \| None | `None` | - | | `mode` | string \| None | `"disabled"` | - | | `api_key` | string \| None | `None` | - | @@ -745,7 +745,7 @@ Configuration for Weights & Biases experiment tracking. | `group` | string \| None | `None` | - | | `notes` | string \| None | `None` | - | | `tags` | list of string \| None | `None` | - | -| `config` | `Dict` \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | (section-distributed-data-parallel)= @@ -808,6 +808,6 @@ Configuration for worker scheduling. Used in the single-controller mode. 
Experim | `endpoint` | string | `"http://localhost:8081"` | - | | `deploy_mode` | string | `"separation"` | - | | `functioncall_service_domain` | string | `"http://localhost:8080"` | - | -| `reward_functioncall_config` | `Dict` | **Required** | - | +| `reward_functioncall_config` | `dict` | **Required** | - | | `reward_model_path` | string | `""` | - | | `reward_model_service_url` | string | `"http://localhost:30000/classify"` | - | diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index 31a920994..6e4e5167e 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -1,4 +1,3 @@ -import dataclasses import os import sys from copy import deepcopy @@ -14,11 +13,9 @@ from areal.platforms import current_platform from areal.utils import seeding, stats_tracker from areal.utils.data import ( - broadcast_tensor_container, concat_padded_tensors, cycle_dataloader, get_batch_size, - tensor_container_to, truncate_dict_to_batch_size, ) from areal.utils.dataloader import create_dataloader @@ -42,7 +39,7 @@ def main(args): config, _ = load_expr_config(args, GRPOConfig) config: GRPOConfig - assert config.actor.dynamic_sampling in ["none", "static", "dynamic"] + assert config.actor.dynamic_sampling_strategy in ["none", "static", "dynamic"] rank = int(os.getenv("RANK")) tokenizer = load_hf_tokenizer(config.tokenizer_path) @@ -64,6 +61,9 @@ def main(args): ) # Create dataset and dataloaders + train_loader_batch_size = ( + config.train_dataset.batch_size // actor.data_parallel_world_size + ) train_dataloader = create_dataloader( train_dataset, rank=actor.data_parallel_rank, @@ -126,7 +126,7 @@ def main(args): ), ) - # Run training. + # Run training saver = Saver(config.saver, ft_spec) stats_logger = StatsLogger(config, ft_spec) evaluator = Evaluator(config.evaluator, ft_spec) @@ -165,52 +165,33 @@ def main(args): # Initialize batch collection with stats_tracker.record_timing("rollout"): - collected_batches, sampling_stats = [], [] + collected_batches, sampling_stats, collected_batches_size = [], [], 0 while True: - if actor.is_data_parallel_head(): - if config.async_training: - new_batch = rollout.prepare_batch( - train_dataloader, - workflow=workflow, - granularity=actor.config.group_size, - # TODO: refactor API in future - should_accept=None, - ) - else: - new_batch = rollout.rollout_batch( - next(data_generator), - workflow=workflow, - granularity=actor.config.group_size, - should_accept=None, - ) - new_batch = tensor_container_to(new_batch, actor.device) - new_batch = broadcast_tensor_container( - new_batch, - src_rank=actor.current_data_parallel_head(), - group=actor.context_and_model_parallel_group, - ) - - # Create barrier to synchronize all rollout processes. 
- dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - # record nums of group acceptance - stats_logger.commit( - dataclasses.asdict(rollout.workflow_executor.rollout_stat) - ) - - with stats_tracker.record_timing("compute_advantage"): - actor.compute_advantages(new_batch) - log_gpu_stats("compute advantages") + if config.async_training: + new_batch = actor.prepare_batch( + train_dataloader, + granularity=actor.config.group_size, + workflow=workflow, + should_accept=lambda sample: True, + ) + else: + new_batch = actor.rollout_batch( + next(data_generator), + granularity=actor.config.group_size, + workflow=workflow, + should_accept=lambda sample: True, + ) # Collect the batch and process it immediately - if config.actor.dynamic_sampling in ["static", "dynamic"]: + if config.actor.dynamic_sampling_strategy in ["static", "dynamic"]: # Filter the current batch by groups filtered_batch, sampling_stat = filter_batch( filter_batch_fn_DAPO, new_batch, config.actor.group_size ) sampling_stats.append(sampling_stat) + breakpoint() - if config.actor.dynamic_sampling == "static": + if config.actor.dynamic_sampling_strategy == "static": # Statistic sampling: No need to refill for static sampling, result in smaller(variant) batch size batch = filtered_batch # Log sampling statistics for static sampling @@ -218,30 +199,30 @@ def main(args): sampling_stats, epoch, step, global_step, stats_logger ) break - elif config.actor.dynamic_sampling == "dynamic": + elif config.actor.dynamic_sampling_strategy == "dynamic": # Dynamic sampling: keep collecting batches until we reach the target batch size # Add filtered batch to collection collected_batches.append(filtered_batch) + collected_batches_size += get_batch_size(filtered_batch) - # Aggregate all filter/clean batches - aggregated_batch = concat_padded_tensors(collected_batches) - expected_batch_size = get_batch_size(new_batch) - aggregated_batch_size = get_batch_size(aggregated_batch) # Check if we have collected enough samples - if aggregated_batch_size >= expected_batch_size: + if collected_batches_size >= train_loader_batch_size: + aggregated_batch = concat_padded_tensors(collected_batches) # Log sampling statistics for dynamic sampling log_sampling_stats( sampling_stats, epoch, step, global_step, stats_logger ) # Truncate batch to train_batch_size batch = truncate_dict_to_batch_size( - data=aggregated_batch, batch_size=expected_batch_size + data=aggregated_batch, + batch_size=train_loader_batch_size, ) break else: # For non-dynamic sampling, just use the current batch batch = new_batch break + breakpoint() if config.actor.recompute_logprob or config.actor.use_decoupled_loss: with stats_tracker.record_timing("recompute_logp"): @@ -254,6 +235,10 @@ def main(args): batch["ref_logp"] = ref.compute_logp(batch) log_gpu_stats("ref logp") + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + with ( stats_tracker.record_timing("train_step"), stats_tracker.scope("grpo_actor"), @@ -293,8 +278,6 @@ def main(args): def evaluate_fn(): if actor.is_data_parallel_head(): - # Stats are logged in workflow - # and will be exported later cnt = 0 for data in valid_dataloader: for item in data: diff --git a/examples/experimental/dapo/gsm8k_dapo.yaml b/examples/experimental/dapo/gsm8k_dapo.yaml index 102064617..a09ee1924 100644 --- a/examples/experimental/dapo/gsm8k_dapo.yaml +++ b/examples/experimental/dapo/gsm8k_dapo.yaml @@ -4,7 +4,7 @@ trial_name: trial0 seed: 1 
total_train_epochs: 10 tokenizer_path: ${actor.path} -async_training: true +async_training: false cluster: n_nodes: 1 @@ -67,7 +67,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: 'dynamic' + dynamic_sampling_strategy: dynamic reward_norm: mean_level: group std_level: group @@ -102,7 +102,7 @@ sglang: # datasets train_dataset: - batch_size: 256 + batch_size: 16 shuffle: true pin_memory: true num_workers: 4 diff --git a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml index 9e90b9a6a..d0d91efc0 100644 --- a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml +++ b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: null diff --git a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml index 7bae82c94..984aaa991 100644 --- a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml +++ b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: batch diff --git a/examples/lora/gsm8k_grpo_lora.yaml b/examples/lora/gsm8k_grpo_lora.yaml index ce114574d..4281021f8 100644 --- a/examples/lora/gsm8k_grpo_lora.yaml +++ b/examples/lora/gsm8k_grpo_lora.yaml @@ -64,7 +64,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group diff --git a/examples/math/boba_grpo_vllm.yaml b/examples/math/boba_grpo_vllm.yaml index 9b0654225..9905d87fd 100644 --- a/examples/math/boba_grpo_vllm.yaml +++ b/examples/math/boba_grpo_vllm.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/deprecated_gsm8k_grpo.yaml b/examples/math/deprecated_gsm8k_grpo.yaml index 33d98647d..d022038e2 100644 --- a/examples/math/deprecated_gsm8k_grpo.yaml +++ b/examples/math/deprecated_gsm8k_grpo.yaml @@ -66,7 +66,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none adv_norm: # Deprecated configuration, change to reward_norm and apply another adv norm mean_level: group std_level: group diff --git a/examples/math/gsm8k_grpo.yaml b/examples/math/gsm8k_grpo.yaml index 85a119247..62a36b20a 100644 --- a/examples/math/gsm8k_grpo.yaml +++ b/examples/math/gsm8k_grpo.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_ppo.yaml b/examples/math/gsm8k_ppo.yaml index 01cc80662..b4caa6208 100644 --- a/examples/math/gsm8k_ppo.yaml +++ b/examples/math/gsm8k_ppo.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none max_new_tokens: ${gconfig.max_new_tokens} critic: diff --git a/examples/math/gsm8k_reinforce.yaml 
b/examples/math/gsm8k_reinforce.yaml index a8d043a65..c1e68ca69 100644 --- a/examples/math/gsm8k_reinforce.yaml +++ b/examples/math/gsm8k_reinforce.yaml @@ -64,7 +64,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/gsm8k_reinforce_baseline.yaml b/examples/math/gsm8k_reinforce_baseline.yaml index 3fe1c8f47..f599e995c 100644 --- a/examples/math/gsm8k_reinforce_baseline.yaml +++ b/examples/math/gsm8k_reinforce_baseline.yaml @@ -64,7 +64,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: null diff --git a/examples/math/gsm8k_rloo.yaml b/examples/math/gsm8k_rloo.yaml index 99b12c784..5d50926b1 100644 --- a/examples/math/gsm8k_rloo.yaml +++ b/examples/math/gsm8k_rloo.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group mean_leave1out: true diff --git a/examples/multi-turn-math/config.yaml b/examples/multi-turn-math/config.yaml index 27fc735ab..ab1e69c87 100644 --- a/examples/multi-turn-math/config.yaml +++ b/examples/multi-turn-math/config.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none adv_norm: mean_level: batch std_level: batch diff --git a/examples/tir/tir_math_config.yaml b/examples/tir/tir_math_config.yaml index 9b8473fe4..82dcbaa5e 100644 --- a/examples/tir/tir_math_config.yaml +++ b/examples/tir/tir_math_config.yaml @@ -60,7 +60,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group diff --git a/examples/vlm/clevr_count_70k_grpo.yaml b/examples/vlm/clevr_count_70k_grpo.yaml index b64aa5bc8..05425a44c 100644 --- a/examples/vlm/clevr_count_70k_grpo.yaml +++ b/examples/vlm/clevr_count_70k_grpo.yaml @@ -63,7 +63,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group diff --git a/recipe/AEnt/actor.py b/recipe/AEnt/actor.py index 03b808315..7ab2a0765 100644 --- a/recipe/AEnt/actor.py +++ b/recipe/AEnt/actor.py @@ -11,11 +11,12 @@ from areal.utils import stats_tracker from areal.utils.data import split_padded_tensor_dict_into_mb_list from areal.utils.functional import ( - dynamic_sampling, + filter_batch, gather_logprobs, gather_logprobs_entropy, ppo_actor_loss_fn, reward_overlong_penalty, + filter_batch_fn_DAPO ) from recipe.AEnt.aent_args import AEntPPOActorConfig from recipe.AEnt.functional import gather_logprobs_clamped_entropy @@ -39,8 +40,10 @@ def __init__(self, config: AEntPPOActorConfig, engine: TrainEngine): def aent_ppo_update( self, data: TensorDict, global_step: int ) -> List[Dict[str, float]]: - if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = dynamic_sampling(data, self.group_size) + if self.dynamic_sampling_strategy and len(data["rewards"]) % self.group_size == 0: + data, sampling_stat = filter_batch( + filter_batch_fn_DAPO, data, self.group_size + ) attn_mask = data["attention_mask"] loss_mask = 
data["loss_mask"] diff --git a/recipe/AEnt/configs/gsm8k_aent_grpo.yaml b/recipe/AEnt/configs/gsm8k_aent_grpo.yaml index bf307b5a2..9d2c9d8f9 100644 --- a/recipe/AEnt/configs/gsm8k_aent_grpo.yaml +++ b/recipe/AEnt/configs/gsm8k_aent_grpo.yaml @@ -75,7 +75,7 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling: false + dynamic_sampling_strategy: none adv_norm: mean_level: group std_level: group From 1db2f50133303f0e2aeb55c08d90f149740b0bf5 Mon Sep 17 00:00:00 2001 From: ZIYI ZENG <1034337098@qq.com> Date: Fri, 24 Oct 2025 15:33:50 +0800 Subject: [PATCH 21/24] Update recipe/AEnt/actor.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- recipe/AEnt/actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipe/AEnt/actor.py b/recipe/AEnt/actor.py index 7ab2a0765..852351d39 100644 --- a/recipe/AEnt/actor.py +++ b/recipe/AEnt/actor.py @@ -40,7 +40,7 @@ def __init__(self, config: AEntPPOActorConfig, engine: TrainEngine): def aent_ppo_update( self, data: TensorDict, global_step: int ) -> List[Dict[str, float]]: - if self.dynamic_sampling_strategy and len(data["rewards"]) % self.group_size == 0: + if self.config.dynamic_sampling_strategy != "none" and len(data["rewards"]) % self.group_size == 0: data, sampling_stat = filter_batch( filter_batch_fn_DAPO, data, self.group_size ) From 8d8224631e0d67bc89630ec2c78b849416fe44c2 Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Thu, 23 Oct 2025 22:35:10 -0900 Subject: [PATCH 22/24] . --- docs/algorithms/dapo.md | 5 +++-- examples/experimental/dapo/gsm8k_dapo.py | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/algorithms/dapo.md b/docs/algorithms/dapo.md index 24a29ba07..013f2b603 100644 --- a/docs/algorithms/dapo.md +++ b/docs/algorithms/dapo.md @@ -45,8 +45,6 @@ We only list the different parameters from GRPO here: - `actor.eps_clip_higher`: The higher bound of clipping. - `actor.dynamic_sampling_strategy`: Define the dynamic sampling strategy, selected from `none`, `static` and `dynamic`. -### Dynamic Sampling Strategy -By default, a group will be filtered out if all tracjectorys in this group have the same reward. You can customize this by `filter_batch_fn_DAPO_per_group` within the `./areal/utils/functional.py` ### Overlong Penalty @@ -61,6 +59,9 @@ Here we briefly introduce the implementation details of DAPO. - `static`: Only one rollout turn and apply the filter function on it, this may result in variable batch size. - `dynamic`: Enable Multi-turn rollout to keep the *constant batch size*. +If `actor.dynamic_sampling_strategy` set to `static` or `dynamic`, a group will be filtered out if all tracjectorys in this group have the same reward. 
You can customize this by `filter_batch_fn_DAPO_per_group` within the `./areal/utils/functional.py` + + ## Example Usage diff --git a/examples/experimental/dapo/gsm8k_dapo.py b/examples/experimental/dapo/gsm8k_dapo.py index 6e4e5167e..27fdd3271 100644 --- a/examples/experimental/dapo/gsm8k_dapo.py +++ b/examples/experimental/dapo/gsm8k_dapo.py @@ -189,7 +189,6 @@ def main(args): filter_batch_fn_DAPO, new_batch, config.actor.group_size ) sampling_stats.append(sampling_stat) - breakpoint() if config.actor.dynamic_sampling_strategy == "static": # Statistic sampling: No need to refill for static sampling, result in smaller(variant) batch size @@ -222,7 +221,6 @@ def main(args): # For non-dynamic sampling, just use the current batch batch = new_batch break - breakpoint() if config.actor.recompute_logprob or config.actor.use_decoupled_loss: with stats_tracker.record_timing("recompute_logp"): From f949c0fcc4220d0a78f3b5c6aa9456cb09e2cb52 Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Sat, 25 Oct 2025 21:34:41 -0900 Subject: [PATCH 23/24] . --- examples/experimental/dapo/gsm8k_dapo.yaml | 1 - examples/experimental/dr.grpo/gsm8k_drgrpo.yaml | 1 - examples/experimental/lite_ppo/gsm8k_liteppo.yaml | 1 - examples/lora/gsm8k_grpo_lora.yaml | 1 - examples/math/boba_grpo_vllm.yaml | 1 - examples/math/deprecated_gsm8k_grpo.yaml | 1 - examples/math/gsm8k_ppo.yaml | 1 - examples/math/gsm8k_reinforce.yaml | 1 - examples/math/gsm8k_reinforce_baseline.yaml | 1 - examples/math/gsm8k_rloo.yaml | 1 - examples/multi-turn-math/config.yaml | 1 - examples/tir/tir_math_config.yaml | 1 - examples/vlm/clevr_count_70k_grpo.yaml | 1 - 13 files changed, 13 deletions(-) diff --git a/examples/experimental/dapo/gsm8k_dapo.yaml b/examples/experimental/dapo/gsm8k_dapo.yaml index a09ee1924..f075f1279 100644 --- a/examples/experimental/dapo/gsm8k_dapo.yaml +++ b/examples/experimental/dapo/gsm8k_dapo.yaml @@ -67,7 +67,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: dynamic reward_norm: mean_level: group std_level: group diff --git a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml index d0d91efc0..3f1817c09 100644 --- a/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml +++ b/examples/experimental/dr.grpo/gsm8k_drgrpo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: null diff --git a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml index 984aaa991..bf6452385 100644 --- a/examples/experimental/lite_ppo/gsm8k_liteppo.yaml +++ b/examples/experimental/lite_ppo/gsm8k_liteppo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: batch diff --git a/examples/lora/gsm8k_grpo_lora.yaml b/examples/lora/gsm8k_grpo_lora.yaml index 4281021f8..2e807d434 100644 --- a/examples/lora/gsm8k_grpo_lora.yaml +++ b/examples/lora/gsm8k_grpo_lora.yaml @@ -64,7 +64,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group diff --git a/examples/math/boba_grpo_vllm.yaml b/examples/math/boba_grpo_vllm.yaml index 9905d87fd..afca6da00 100644 --- a/examples/math/boba_grpo_vllm.yaml +++ b/examples/math/boba_grpo_vllm.yaml 
@@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/deprecated_gsm8k_grpo.yaml b/examples/math/deprecated_gsm8k_grpo.yaml index d022038e2..efc91502a 100644 --- a/examples/math/deprecated_gsm8k_grpo.yaml +++ b/examples/math/deprecated_gsm8k_grpo.yaml @@ -66,7 +66,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none adv_norm: # Deprecated configuration, change to reward_norm and apply another adv norm mean_level: group std_level: group diff --git a/examples/math/gsm8k_ppo.yaml b/examples/math/gsm8k_ppo.yaml index b4caa6208..e44056377 100644 --- a/examples/math/gsm8k_ppo.yaml +++ b/examples/math/gsm8k_ppo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none max_new_tokens: ${gconfig.max_new_tokens} critic: diff --git a/examples/math/gsm8k_reinforce.yaml b/examples/math/gsm8k_reinforce.yaml index c1e68ca69..5c056eacb 100644 --- a/examples/math/gsm8k_reinforce.yaml +++ b/examples/math/gsm8k_reinforce.yaml @@ -64,7 +64,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/gsm8k_reinforce_baseline.yaml b/examples/math/gsm8k_reinforce_baseline.yaml index f599e995c..eb3534436 100644 --- a/examples/math/gsm8k_reinforce_baseline.yaml +++ b/examples/math/gsm8k_reinforce_baseline.yaml @@ -64,7 +64,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: null diff --git a/examples/math/gsm8k_rloo.yaml b/examples/math/gsm8k_rloo.yaml index 5d50926b1..a4f41a4c1 100644 --- a/examples/math/gsm8k_rloo.yaml +++ b/examples/math/gsm8k_rloo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none reward_norm: mean_level: group mean_leave1out: true diff --git a/examples/multi-turn-math/config.yaml b/examples/multi-turn-math/config.yaml index ab1e69c87..d4d01aac1 100644 --- a/examples/multi-turn-math/config.yaml +++ b/examples/multi-turn-math/config.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none adv_norm: mean_level: batch std_level: batch diff --git a/examples/tir/tir_math_config.yaml b/examples/tir/tir_math_config.yaml index 82dcbaa5e..06323609e 100644 --- a/examples/tir/tir_math_config.yaml +++ b/examples/tir/tir_math_config.yaml @@ -60,7 +60,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group diff --git a/examples/vlm/clevr_count_70k_grpo.yaml b/examples/vlm/clevr_count_70k_grpo.yaml index 05425a44c..bdfbda561 100644 --- a/examples/vlm/clevr_count_70k_grpo.yaml +++ b/examples/vlm/clevr_count_70k_grpo.yaml @@ -63,7 +63,6 @@ actor: recompute_logprob: true use_decoupled_loss: true behav_imp_weight_cap: 5.0 - dynamic_sampling_strategy: none reward_norm: mean_level: group std_level: group From f89f1fa516b89d0dc4f1f58bee7292bde4ae3490 Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Sat, 25 Oct 2025 21:37:25 -0900 Subject: [PATCH 24/24] . 
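Usage note for the help-string change below: since this string is now the only in-code description of the allowed values, here is a short sketch of how the DAPO example is expected to validate and branch on the field. The allowed values come from the help text and the example script's assertion; the branch bodies are placeholders, not the real control flow.

```python
# Sketch of consuming dynamic_sampling_strategy; branch bodies are placeholders.
ALLOWED = ("none", "static", "dynamic")


def check_strategy(strategy: str) -> str:
    if strategy not in ALLOWED:
        raise ValueError(
            f"dynamic_sampling_strategy must be one of {ALLOWED}, got {strategy!r}"
        )
    return strategy


strategy = check_strategy("dynamic")
if strategy == "none":
    pass  # use each rollout batch as-is
elif strategy == "static":
    pass  # filter once and accept a smaller, variable batch size
else:
    pass  # filter and refill until the target batch size is reached
```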
--- areal/api/cli_args.py | 2 +- docs/cli_reference.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 12407bb10..6039b63cb 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -487,7 +487,7 @@ class PPOActorConfig(TrainEngineConfig): dynamic_sampling_strategy: str = field( default="none", metadata={ - "help": "Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. See the doc for more details" + "help": "Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. Only effective when running DAPO script." }, ) diff --git a/docs/cli_reference.md b/docs/cli_reference.md index cd3f6d1ac..1a935a7d7 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -356,7 +356,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | | `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | | `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling_strategy` | string | `"none"` | Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. See the doc for more details | +| `dynamic_sampling_strategy` | string | `"none"` | Dynamic sampling strategy. Select from `none`, `dynamic` and `static`. Only effective when running DAPO script. | | `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | | `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | | `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate |
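Closing note for this series: `log_sampling_stats` is called throughout the DAPO script but its body never appears in these patches. The sketch below reconstructs a plausible implementation from the inline logging code that an earlier commit in the series removed (aggregate the per-round group counters, derive keep/filter ratios, commit them via the stats logger). Treat it as an assumption about the helper, not the shipped code.

```python
# Hypothetical reconstruction of log_sampling_stats(), based on the inline
# logging removed earlier in this series. The real helper may differ.
def log_sampling_stats(sampling_stats, epoch, step, global_step, stats_logger):
    n_kept = sum(s["n_group_kept"] for s in sampling_stats)
    n_filtered = sum(s["n_group_filtered"] for s in sampling_stats)
    keep_ratio = n_kept / max(n_kept + n_filtered, 1)
    stats_logger.commit(
        epoch=epoch,
        step=step,
        global_step=global_step,
        data={"keep_ratio": keep_ratio, "filter_ratio": 1.0 - keep_ratio},
    )
```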