29 changes: 29 additions & 0 deletions examples/entropy/README.md
@@ -0,0 +1,29 @@
# Entropy dynamics of RL training

This example demonstrates the two algorithms **Clip_B** and **Clip_V** from the paper [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392). Clip_B keeps all positive-advantage tokens and one-side clips negative-advantage tokens based on an entropy-derived signal.

## Data Preparation

We use the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset as our training set and hold out 500 of its questions to form the validation set (denoted dapo-validation-500).
We then filter out training samples with excessively high (≥ 15/16) or low (≤ 1/16) pass rates, as evaluated by Qwen2.5-7B-Instruct.
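
A minimal sketch of this pass-rate filter (the `passes` field, holding the number of correct completions out of 16 samples from Qwen2.5-7B-Instruct, is an illustrative assumption computed offline, not part of this repo):

```python
def filter_by_pass_rate(dataset, low=1, high=15):
    """Keep questions whose pass count lies strictly between low and high.

    Questions with <= 1/16 or >= 15/16 correct completions carry little
    learning signal for GRPO-style training, so they are dropped.
    """
    kept = []
    for example in dataset:
        # `example["passes"]` is assumed: correct completions out of 16.
        if low < example["passes"] < high:
            kept.append(example)
    return kept
```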

## Clip_B Experiment

1. Apply the patch to keep entropy information in the trainer batch:

```bash
cd /path/to/Trinity-RFT
git apply examples/entropy/clipb_trainer.patch
```

2. Update the dataset paths in the config file [`clipb.yaml`](clipb.yaml) to point to your local data (the snippet after these steps shows which fields to edit).

3. Run the experiment:

```bash
trinity run examples/entropy/clipb.yaml
```
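
The fields to update in step 2 are the taskset and eval-taskset paths; the values below are placeholders:

```yaml
buffer:
  explorer_input:
    taskset:
      path: /path/to/dapo-math-17k      # processed DAPO-Math-17k
    eval_tasksets:
      - name: dapo-validation-500
        path: /path/to/dapo-validation  # 500 held-out questions
```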

## Clip_V Implementation

Coming soon.
100 changes: 100 additions & 0 deletions examples/entropy/clipb.yaml
@@ -0,0 +1,100 @@
project: math_dapo
name: clipb_example
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
max_prompt_tokens: 1024
max_response_tokens: 7168
algorithm:
algorithm_type: grpo_verl
advantage_fn: clipb
advantage_fn_args:
mu: 2.5
repeat_times: 16
kl_loss_fn_args:
kl_coef: 0.0
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 20
batch_size: 64
explorer_input:
taskset:
name: dapo_235
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k
format:
prompt_key: 'question'
response_key: 'ground_truth'
rollout_args:
temperature: 1.0
logprobs: 20
eval_tasksets:
- name: dapo-validation-500
storage_type: file
path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k
split: 'test'
repeat_times: 32
format:
prompt_key: 'question'
response_key: 'ground_truth'
rollout_args:
temperature: 0.7
- name: amc23
storage_type: file
path: math-ai/amc23 # Path to the AMC23 dataset
split: 'test'
repeat_times: 32
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 0.7
- name: aime24
storage_type: file
path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset
split: 'train'
repeat_times: 32
format:
prompt_key: 'problem'
response_key: 'answer'
rollout_args:
temperature: 0.7
- name: aime25
storage_type: file
path: math-ai/aime25 # Path to the AIME2025 dataset
split: 'test'
repeat_times: 32
format:
prompt_key: 'problem'
response_key: 'answer'
rollout_args:
temperature: 0.7
default_workflow_type: 'async_math_workflow'
default_reward_fn_type: 'math_boxed_reward'
trainer_input:
experience_buffer:
name: math_buffer
storage_type: queue
max_read_timeout: 7200
explorer:
eval_interval: 20
eval_on_startup: true
runner_per_model: 8
rollout_model:
engine_type: vllm_async
engine_num: 4
tensor_parallel_size: 1
seed: 42
trainer:
trainer_type: 'verl'
save_interval: 200
trainer_config:
algorithm:
rollout_correction:
bypass_mode: false
synchronizer:
sync_method: 'nccl'
sync_interval: 1
sync_timeout: 3200
11 changes: 11 additions & 0 deletions examples/entropy/clipb_trainer.patch
@@ -0,0 +1,11 @@
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -501,7 +501,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
}
metrics.update(old_log_prob_metrics)
- old_log_prob.batch.pop("entropys")
+ # Keep entropys in batch so advantage_fn (e.g. Clip_B) can use it
+ # old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
1 change: 1 addition & 0 deletions trinity/algorithm/__init__.py
@@ -29,6 +29,7 @@
"multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm",
"on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm",
"jsd": "trinity.algorithm.algorithm.JSDAlgorithm",
"grpo_verl": "trinity.algorithm.algorithm.GRPOverlAlgorithm",
},
)

1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -19,6 +19,7 @@
"rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage",
"on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage",
"jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage",
"clipb": "trinity.algorithm.advantage_fn.clipb_advantage.ClipBAdvantageFn",
},
)

152 changes: 152 additions & 0 deletions trinity/algorithm/advantage_fn/clipb_advantage.py
@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""Advantage computation for Clip_B
Ref: https://arxiv.org/pdf/2602.03392"""

from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Tuple

import torch

if TYPE_CHECKING:
from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn


class ClipBAdvantageFn(AdvantageFn):
"""Clip_B advantage: keep all positive-advantage tokens,
one-side clip negative-advantage tokens by entropy signal."""

def __init__(
self,
epsilon: float = 1e-6,
mu: float = 2.5,
) -> None:
self.epsilon = epsilon
self.mu = mu

def __call__(
self,
exps: "DataProto",
**kwargs,
) -> Tuple["DataProto", Dict]:
"""
Compute advantage for Clip_B.
exps should contain the following fields:
- token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
- response_mask: `(torch.Tensor)`
shape: (bs, response_length)
- uid: group index per sample, stored in `non_tensor_batch`
shape: (bs,)
- rollout_log_probs: `(torch.Tensor)`
shape: (bs, response_length)
- entropys: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
exps: DataProto with advantages and returns added
metrics: Dict with clipping metrics
"""
token_level_rewards = exps.batch["token_level_rewards"]
response_mask = exps.batch["response_mask"]
index = exps.non_tensor_batch["uid"]

response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)

id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])

for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device)
id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device)
elif len(id2score[idx]) > 1:
group_scores = torch.stack(id2score[idx]).to(
dtype=scores.dtype, device=scores.device
)
id2mean[idx] = torch.mean(group_scores)
id2std[idx] = torch.std(group_scores)
else:
raise ValueError(f"no score in prompt index: {idx}")

for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask

exps.batch["advantages"] = scores
exps.batch["returns"] = scores.clone()

# --- BEGIN: token filtering logic ---
# Build the clipping signal S = p * (H + log p) from the rollout
# logprobs and the entropy recomputed by the trainer (kept in the
# batch via clipb_trainer.patch).
LP = exps.batch["rollout_log_probs"] # [B, T], rollout logprobs
H = exps.batch["entropys"] # [B, T], recomputed entropy
M = response_mask # [B, T], mask of valid tokens
p = LP.exp() # [B, T], token probability
S = p * (H + LP) # [B, T], clipping signal per token

# Detach for constructing clip mask (no gradient needed)
xS = S.detach().to(torch.float32) # [B, T]
m = M.to(torch.float32) # [B, T]

# Masked global mean & variance (population variance, denominator = n)
n = m.sum().clamp_min(1.0)
ES = (xS * m).sum() / n # scalar
varS = ((xS - ES) ** 2 * m).sum() / n # scalar
stdS = varS.sqrt() # scalar

# Centered signal
z = xS - ES # [B, T]

# if stdS is too small, keep all tokens; otherwise
# keep all positive-advantage tokens; one-side clip negative-advantage tokens
if stdS.item() < 1e-12:
keep = torch.ones_like(M, dtype=M.dtype) # all kept
else:
A = exps.batch["advantages"].detach().to(torch.float32) # [B, T]
pos_mask = A > 0
neg_mask = A < 0

keep_pos = torch.ones_like(pos_mask, dtype=torch.bool) # positive: all kept
keep_neg = z >= -(self.mu * stdS) # negative: lower-side clip
keep_zero = torch.ones_like(pos_mask, dtype=torch.bool) # zero: all kept

keep_bool = torch.where(pos_mask, keep_pos, torch.where(neg_mask, keep_neg, keep_zero))
keep = keep_bool.to(M.dtype)

M_clipped = M * keep
exps.batch["response_mask"] = M_clipped
# --- END: token filtering logic ---

# Monitoring metrics
total_tokens = m.sum().clamp_min(1.0)
frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item()

A = exps.batch["advantages"].detach().to(torch.float32)
pos_mask = (A > 0).to(M.dtype)
neg_mask = (A < 0).to(M.dtype)
total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0)
total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0)
frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item()
frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item()

metrics = {
"frac_clipped": frac_clipped,
"frac_clipped_pos": frac_clipped_pos,
"frac_clipped_neg": frac_clipped_neg,
"ES": ES.item(),
"varS": varS.item(),
}
return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {
"epsilon": 1e-6,
"mu": 2.5,
}
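
For intuition, here is a standalone sketch of the clipping rule above on toy tensors; all values are made up for illustration, and only `mu` matches the default above:

```python
import torch

mu = 2.5
LP = torch.tensor([[0.5, 0.1, 0.9], [0.3, 0.6, 0.2]]).log()  # logprobs [B, T]
H = torch.tensor([[1.2, 2.0, 0.4], [1.5, 0.8, 1.9]])         # entropy  [B, T]
M = torch.ones(2, 3)                                          # response mask
A = torch.tensor([[1.0, -1.0, -1.0], [-1.0, 1.0, 1.0]])      # advantages

S = LP.exp() * (H + LP)                        # clipping signal per token
n = M.sum().clamp_min(1.0)
ES = (S * M).sum() / n                         # masked mean
stdS = (((S - ES) ** 2 * M).sum() / n).sqrt()  # masked population std

# Positive/zero-advantage tokens are always kept; negative-advantage
# tokens are dropped when their centered signal falls below -mu * stdS.
keep = torch.where(A < 0, S - ES >= -(mu * stdS), torch.ones(2, 3, dtype=torch.bool))
print(M * keep)  # clipped response mask
```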
22 changes: 22 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -540,3 +540,25 @@ def default_config(cls) -> Dict:
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
}


class GRPOverlAlgorithm(AlgorithmType):
"""GRPO algorithm, but advantage computation is done in trainer."""

use_critic: bool = False
use_reference: bool = True
compute_advantage_in_trainer: bool = True
can_balance_batch: bool = True
schema: str = "experience"

@classmethod
def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"advantage_fn": "grpo",
"sample_strategy": "default",
"policy_loss_fn": "ppo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
}
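
For reference, `examples/entropy/clipb.yaml` above overrides these defaults to plug in the Clip_B advantage:

```yaml
algorithm:
  algorithm_type: grpo_verl  # advantages computed in the trainer
  advantage_fn: clipb        # ClipBAdvantageFn, registered above
  advantage_fn_args:
    mu: 2.5
  repeat_times: 16
```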
1 change: 1 addition & 0 deletions trinity/common/verl_config.py
@@ -175,6 +175,7 @@ class Actor:
router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
# do not set
loss_agg_mode: str = "token-mean"
loss_scale_factor: Optional[float] = None
clip_ratio: float = 0.2
clip_ratio_low: Optional[float] = None
clip_ratio_high: Optional[float] = None