Commit 57afe2f

Co-authored-by: Yanxi Chen <[email protected]>
1 parent 375ad7d commit 57afe2f

File tree

9 files changed, +211 -3 lines changed


examples/asymre_gsm8k/gsm8k.yaml

Lines changed: 7 additions & 3 deletions
@@ -1,6 +1,10 @@
-project: sync_offset_0_sync20
-name: asymre-gsm8k_shift-0.1
-checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+# Configuration file for the AsymRE GSM8k project.
+# REINFORCE for off-Policy Reinforcement Learning: Balancing positive and negative rewards
+# https://arxiv.org/abs/2506.20520.
+
+project: "Trinity-RFT-GSM8K"
+name: asymre_gsm8k
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL/
   max_response_tokens: 1024

examples/sppo_gsm8k/README.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Example: sPPO on GSM8k dataset
+
+This example shows how to use [sPPO](https://arxiv.org/abs/2108.05828) on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k).
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).
+
+The config files are [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).

examples/sppo_gsm8k/gsm8k.yaml

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Configuration file for the sPPO GSM8k project.
+# A general class of surrogate functions for stable and efficient reinforcement learning
+# https://arxiv.org/abs/2108.05828.
+
+project: "Trinity-RFT-GSM8K"
+name: sppo_gsm8k
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+  max_response_tokens: 1024
+  max_model_len: 1280
+algorithm:
+  algorithm_type: sppo
+  policy_loss_fn_args:
+    epsilon: 0.1
+  repeat_times: 8
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_steps: 100
+  batch_size: 96
+  max_retry_times: 3
+  max_retry_interval: 1
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      split: train
+      format:
+        prompt_key: question
+        response_key: answer
+      rollout_args:
+        temperature: 1.0
+    eval_tasksets:
+      - name: gsm8k-eval
+        storage_type: file
+        path: /PATH/TO/DATASET/
+        split: test
+        format:
+          prompt_key: question
+          response_key: answer
+    default_workflow_type: math_workflow
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+explorer:
+  eval_interval: 20
+  runner_num: 64
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
+synchronizer:
+  sync_method: nccl
+  sync_interval: 20
+  sync_timeout: 1200
+  sync_offset: 0
+trainer:
+  trainer_type: verl
+  trainer_config_path: examples/sppo_gsm8k/train_gsm8k.yaml
+  save_interval: 100
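
The `policy_loss_fn_args` block above presumably feeds the constructor of the sPPO loss registered later in this commit, overriding its default clipping width of 0.3. A minimal sketch of that parameter flow (the real wiring goes through Trinity-RFT's registries; this only shows where `epsilon: 0.1` ends up):

# Sketch only: mirrors how `algorithm.policy_loss_fn_args` from gsm8k.yaml would be
# passed to the loss class added in this commit; not the framework's own wiring.
from trinity.algorithm.policy_loss_fn import sPPOPolicyLossFn

loss_fn = sPPOPolicyLossFn(epsilon=0.1)  # config value; the class default is 0.3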
examples/sppo_gsm8k/train_gsm8k.yaml

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True # False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_micro_batch_size_per_gpu: 8
+    use_dynamic_bsz: True # False
+    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 1.0
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+      # min_lr_ratio: null # only useful for warmup with cosine
+      warmup_style: constant # select from constant/cosine
+      total_training_steps: -1
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size_per_gpu: 16
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+  balance_batch: True
+  # auto: find the last ckpt to resume. If can't find, start from scratch
+  resume_mode: auto # or auto or resume_path if
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
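
As a side note on the `ppo_max_token_len_per_gpu` comment above (`n * ${data.max_prompt_length} + ${data.max_response_length}`): combined with the limits in `gsm8k.yaml` (`max_model_len: 1280`, `max_response_tokens: 1024`), a rough capacity check looks like this. The prompt budget below is an assumption (`max_model_len - max_response_tokens`), not a value taken from either config.

# Back-of-the-envelope check of the dynamic-batching token budget.
max_prompt_len = 1280 - 1024          # assumed prompt budget: max_model_len - max_response_tokens
max_response_len = 1024               # max_response_tokens from gsm8k.yaml
ppo_max_token_len_per_gpu = 16384     # from train_gsm8k.yaml above
full_length_seqs = ppo_max_token_len_per_gpu // (max_prompt_len + max_response_len)
print(full_length_seqs)               # 12 worst-case sequences per GPU micro-batch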

trinity/algorithm/advantage_fn/asymre_advantage.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def calculate_group_advantage(
             exp.returns = exp.advantages.clone()
         metrics = {
             "group_baseline": group_baseline.item(),
+            "reward_mean": group_baseline.item() - self.baseline_shift,
         }
     return exps, metrics
trinity/algorithm/advantage_fn/opmd_advantage.py

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ def calculate_group_advantage(
             exp.returns = exp.advantages.clone()
         metrics = {
             "group_baseline": group_baseline.item(),
+            "reward_mean": torch.mean(group_rewards).item(),
         }
     return exps, metrics

trinity/algorithm/algorithm.py

Lines changed: 23 additions & 0 deletions
@@ -269,3 +269,26 @@ def default_config(cls) -> Dict:
             "kl_loss_fn": "none",
             "entropy_loss_fn": "none",
         }
+
+
+@ALGORITHM_TYPE.register_module("sppo")
+class sPPOAlgorithm(AlgorithmType):
+    """sPPO Algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = False
+    compute_advantage_in_trainer: bool = False
+    can_balance_batch: bool = True
+    schema: str = "experience"
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "sample_strategy": "warmup",
+            "policy_loss_fn": "sppo",
+            "advantage_fn": "opmd",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "none",
+            "entropy_loss_fn": "none",
+        }
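
The sPPO defaults point `advantage_fn` at `"opmd"`, which is likely why `opmd_advantage.py` also gains the `reward_mean` metric in this commit. A quick check against this commit of the repo:

from trinity.algorithm.algorithm import sPPOAlgorithm

config = sPPOAlgorithm.default_config()
print(config["policy_loss_fn"])  # "sppo"
print(config["advantage_fn"])    # "opmd" -- reuses the OPMD group advantage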

trinity/algorithm/policy_loss_fn/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
+from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn

 __all__ = [
     "POLICY_LOSS_FN",
@@ -23,4 +24,5 @@
     "MIXCHORDPolicyLossFn",
     "SFTISLossFn",
     "SFTPhiLossFn",
+    "sPPOPolicyLossFn",
 ]
trinity/algorithm/policy_loss_fn/sppo_loss_fn.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+"""sPPO-token policy loss function.
+Relevant paper: https://arxiv.org/abs/2108.05828.
+"""
+
+from typing import Dict, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("sppo")
+class sPPOPolicyLossFn(PolicyLossFn):
+    def __init__(
+        self,
+        backend: str = "verl",
+        epsilon: float = 0.3,
+    ) -> None:
+        super().__init__(backend=backend)
+        self.epsilon = epsilon
+
+    def __call__(  # type: ignore
+        self,
+        logprob: torch.Tensor,  # [batch_size, seq_len]
+        old_logprob: torch.Tensor,  # [batch_size, seq_len]
+        action_mask: torch.Tensor,  # [batch_size, seq_len]
+        advantages: torch.Tensor,  # [batch_size, seq_len]
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        """Calculate the sPPO loss.
+        The formula is as follows:
+            advantages * log(clip(ratio, 1/(1+epsilon), 1+epsilon))
+            ratio = exp(logprob - old_logprob)
+        """
+        # Token-wise clipped surrogate: the detached ratio acts only as an
+        # indicator, so tokens outside [1/(1+epsilon), 1+epsilon] get no gradient.
+        ratio = torch.exp(logprob - old_logprob).detach()
+        is_in_range = (ratio >= (1 / (1 + self.epsilon))) * (ratio <= (1 + self.epsilon))
+        is_clipped_mask = ~is_in_range
+        pg_losses = -advantages * (logprob - old_logprob) * is_in_range.float()
+        pg_loss = masked_mean(pg_losses, action_mask)
+        pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
+        metrics = {
+            "pg_clipfrac": pg_clipfrac.item(),
+            "pg_loss": pg_loss.detach().item(),
+        }
+        return pg_loss, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "epsilon": 0.3,
+        }
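
The implementation above uses the detached ratio only as an in-range indicator, so inside the clipping band its gradient matches the docstring's `advantages * log(clip(ratio, 1/(1+epsilon), 1+epsilon))` objective, and the clip zeroes the gradient outside the band. A standalone numeric check of that equivalence (plain torch, not using the repo's classes):

import torch

eps = 0.3
adv = torch.tensor(2.0)
old_logprob = torch.tensor(-1.0)

# Gradient of the docstring form: -A * log(clip(ratio, 1/(1+eps), 1+eps)).
logprob = torch.tensor(-0.9, requires_grad=True)
ratio = torch.exp(logprob - old_logprob)
loss_clip = -adv * torch.log(torch.clamp(ratio, 1 / (1 + eps), 1 + eps))
loss_clip.backward()
grad_clip = logprob.grad.item()

# Gradient of the implemented form: -A * 1{in range} * (logprob - old_logprob).
logprob = torch.tensor(-0.9, requires_grad=True)
ratio = torch.exp(logprob - old_logprob).detach()
in_range = ((ratio >= 1 / (1 + eps)) & (ratio <= 1 + eps)).float()
loss_impl = -adv * (logprob - old_logprob) * in_range
loss_impl.backward()
grad_impl = logprob.grad.item()

print(grad_clip, grad_impl)  # both -2.0 while the ratio stays inside the clip band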
