
Commit a80d33f

support different kl estimator, support reinforce++ and reinforce++-baseline (#408)
1 parent 20a8acc · commit a80d33f

File tree

5 files changed: +380 −4 lines

areal/api/cli_args.py

Lines changed: 8 additions & 2 deletions

@@ -23,7 +23,8 @@ class NormConfig:
     mean_level: str | None = field(
         default="batch",
         metadata={
-            "help": "Mean level for normalization. Choices: batch, group. Omit for no mean normalization."
+            "help": "Mean level for normalization. None for no mean normalization.",
+            "choices": ["batch", "group", None],
         },
     )
     mean_leave1out: bool = field(
@@ -33,7 +34,8 @@ class NormConfig:
     std_level: str | None = field(
         default="batch",
         metadata={
-            "help": "Standard deviation level for normalization. Choices: batch, group. Omit for no std normalization."
+            "help": "Standard deviation level for normalization. None for no std normalization.",
+            "choices": ["batch", "group", None],
         },
     )
     std_unbiased: bool = field(
@@ -374,6 +376,10 @@ class PPOActorConfig(TrainEngineConfig):
 
     # KL Control
     kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
+    kl_estimator: str = field(
+        default="k1",
+        metadata={"help": "KL divergence estimator", "choices": ["k1", "k2", "k3"]},
+    )
 
     # Asynchronous RL
     recompute_logprob: bool = field(
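
A side note on these config changes: the accepted values now live in a machine-readable "choices" list rather than only in the help string, and passing None explicitly disables the corresponding normalization. A minimal sketch of constructing the dataclass (values illustrative; assumes the remaining NormConfig fields keep their defaults):

from areal.api.cli_args import NormConfig

# Batch-level mean and std normalization, as used by the REINFORCE++ example below.
adv_norm = NormConfig(mean_level="batch", std_level="batch")

# None disables the corresponding normalization step entirely.
no_mean_norm = NormConfig(mean_level=None, std_level="batch")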

areal/engine/ppo/actor.py

Lines changed: 7 additions & 2 deletions

@@ -7,7 +7,11 @@
 from areal.api.engine_api import TrainEngine
 from areal.engine.fsdp_engine import FSDPEngine
 from areal.utils import stats_tracker
-from areal.utils.data import Normalization, split_padded_tensor_dict_into_mb_list
+from areal.utils.data import (
+    KLEstimator,
+    Normalization,
+    split_padded_tensor_dict_into_mb_list,
+)
 from areal.utils.functional import (
     dynamic_sampling,
     gather_logprobs,
@@ -30,6 +34,7 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine):
         self.group_size = config.group_size
 
         self.kl_ctl = config.kl_ctl
+        self.kl_estimator = KLEstimator(config.kl_estimator)
 
         self.adv_norm = Normalization(config.adv_norm) if config.adv_norm else None
         self.reward_norm = (
@@ -110,7 +115,7 @@ def compute_advantages(self, data: Dict[str, Any]) -> None:
         attn_mask = data["attention_mask"]
         seqlens = attn_mask.sum(-1).long()
         seq_no_eos_mask = seqlens == attn_mask.shape[1]
-        rewards = -self.kl_ctl * (old_logp - ref_logp)
+        rewards = -self.kl_ctl * self.kl_estimator(old_logp, ref_logp)
         kl_rewards = rewards.clone()
         # KL rewards at the next token after eos is zero.
         rewards[batch_indices, seqlens - 1] = 0
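
The behavioral change here is confined to the last hunk: the hard-coded log-ratio old_logp - ref_logp (i.e. the k1 estimator) is replaced by the configured estimator, so k2 or k3 can be selected without touching the advantage logic. A toy sketch of the resulting per-token KL penalty, with hypothetical shapes (batch, seqlen) and illustrative stand-ins for seqlens and batch_indices:

import torch

old_logp = torch.randn(2, 8) - 2.0  # log-probs under the sampling policy
ref_logp = torch.randn(2, 8) - 2.0  # log-probs under the reference policy
kl_ctl = 0.001

# k1 penalty, matching the previous hard-coded behavior:
rewards = -kl_ctl * (old_logp.float() - ref_logp.float())

# Zero the KL reward at the final generated token, mirroring the code above.
seqlens = torch.tensor([5, 8])
batch_indices = torch.arange(2)
rewards[batch_indices, seqlens - 1] = 0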

areal/utils/data.py

Lines changed: 61 additions & 0 deletions

@@ -1301,3 +1301,64 @@ def _compute_std(
     if factor.item() == 0:
         return torch.ones_like(x_sum_sq)
     return (x_sum_sq / factor).sqrt()
+
+
+class KLEstimator:
+    """
+    KL divergence estimator, supports k1, k2 and k3.
+    """
+
+    def __init__(self, kl_estimator: str = "k1", apply_clamp: bool = True):
+        self.kl_estimator = kl_estimator
+        if kl_estimator not in ["k1", "k2", "k3"]:
+            raise ValueError(
+                f"Invalid KL estimator: {kl_estimator}. Valid choices: k1, k2, k3"
+            )
+        self.apply_clamp = apply_clamp
+
+    def __call__(
+        self, log_probs: torch.Tensor, log_probs_base: torch.Tensor
+    ) -> torch.Tensor:
+        return self._compute_approx_kl(
+            log_probs, log_probs_base, self.kl_estimator, self.apply_clamp
+        )
+
+    # adapted from https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/utils.py#L7
+    @staticmethod
+    def _compute_approx_kl(
+        log_probs: torch.Tensor,
+        log_probs_base: torch.Tensor,
+        kl_estimator: str = "k1",
+        apply_clamp: bool = True,
+    ) -> torch.Tensor:
+        """
+        Compute the approximate KL divergence between two distributions.
+        Schulman blog: http://joschu.net/blog/kl-approx.html
+
+        Args:
+            log_probs: Log probabilities of the new distribution.
+            log_probs_base: Log probabilities of the base distribution.
+        """
+
+        if kl_estimator == "k1":
+            log_ratio = log_probs.float() - log_probs_base.float()
+
+        # The k2 estimator is the non negative kl approximation in
+        # http://joschu.net/blog/kl-approx.html
+        # The k2_loss is approximately equivalent to the
+        # one-step KL divergence penalty with the k1 estimator
+        # used in https://arxiv.org/pdf/2310.10505.
+        if kl_estimator == "k2":
+            log_ratio = log_probs.float() - log_probs_base.float()
+            log_ratio = log_ratio**2 / 2.0
+
+        # The k3 estimator is the non negative kl approximation in
+        # http://joschu.net/blog/kl-approx.html
+        if kl_estimator == "k3":
+            log_ratio = log_probs.float() - log_probs_base.float()
+            log_ratio = -log_ratio
+            log_ratio = log_ratio.exp() - 1 - log_ratio
+
+        if apply_clamp:
+            log_ratio = log_ratio.clamp(min=-10, max=10)
+        return log_ratio
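
Writing d = log_probs - log_probs_base, the three estimators are k1 = d, k2 = d^2 / 2, and k3 = exp(-d) - 1 + d (see Schulman's blog linked above). Under samples from the current policy, k1 is the unbiased but sign-indefinite estimate of KL(pi || pi_ref); k2 is biased but low-variance; k3 is both unbiased (since E[exp(-d)] = 1) and non-negative. A quick check of the class above on toy tensors:

import torch
from areal.utils.data import KLEstimator

logp = torch.tensor([-1.0, -2.0, -0.5])
logp_ref = torch.tensor([-1.2, -1.8, -0.6])

for name in ("k1", "k2", "k3"):
    print(name, KLEstimator(name)(logp, logp_ref))
# k1 entries can be negative; k2 and k3 are non-negative elementwise.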

examples/math/gsm8k_reinforce.yaml

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
+experiment_name: gsm8k-reinforce-plus-plus
+trial_name: trial0
+
+seed: 1
+total_train_epochs: 10
+tokenizer_path: ${actor.path}
+async_training: true
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
+  fileroot: /tmp/areal/experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /tmp/areal/name_resolve
+
+allocation_mode: sglang.d4p1t1+d4p1t1
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 256
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_new_tokens: 1024
+  greedy: false
+  temperature: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: Qwen/Qwen2.5-1.5B-Instruct
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: false
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer:
+    type: adam
+    lr: 1.70e-5
+    weight_decay: 0.017
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  backend: fsdp
+  group_size: ${gconfig.n_samples}
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.001
+  kl_estimator: k1
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behav_imp_weight_cap: 5.0
+  dynamic_sampling: false
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  max_new_tokens: ${gconfig.max_new_tokens}
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer: null
+  backend: fsdp
+
+# SGLang
+sglang:
+  model_path: ${actor.path}
+  random_seed: ${seed}
+  skip_tokenizer_init: true
+  dtype: ${actor.dtype}
+  max_running_requests: null
+  context_length: 32768
+  mem_fraction_static: 0.8
+
+# datasets
+train_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+launcher:
+  inference_server_cpus_per_gpu: 4
+  inference_server_mem_per_gpu: 32768
+  trainer_cpus_per_gpu: 4
+  trainer_mem_per_gpu: 32768
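
This config realizes REINFORCE++: a token-level KL penalty folded into the reward (kl_ctl, kl_estimator) plus batch-level advantage normalization (adv_norm with mean_level: batch and std_level: batch). The commit message also mentions REINFORCE++-baseline; by the common convention (e.g. in OpenRLHF) that variant subtracts a per-prompt group mean instead. A hypothetical counterpart, assuming the NormConfig fields shown earlier (the corresponding YAML file is not included in this excerpt):

from areal.api.cli_args import NormConfig

# Assumed REINFORCE++-baseline setting: group-mean baseline, batch-level std.
baseline_adv_norm = NormConfig(mean_level="group", std_level="batch")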
