
Commit f88181c: Add GSPO-style REC variant (#380)
1 parent 106e69c

5 files changed: +88 -16 lines

examples/rec_gsm8k/README.md (+48 -6)

@@ -83,11 +83,10 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "none"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-OneSide-IS:**
@@ -100,11 +99,42 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "importance_sampling"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
+```
+
+**REC-GSPO-NoIS:**
+
+```
+algorithm:
+  algorithm_type: rec
+  policy_loss_fn_args:
+    epsilon_low: 3e-4
+    epsilon_high: 4e-4
+    clip_mode: "gspo-one-side"
+    weight: "none"
+  advantage_fn_args:
+    std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
+```
+
+**REC-GSPO-IS:**
+
+```
+algorithm:
+  algorithm_type: rec
+  policy_loss_fn_args:
+    epsilon_low: 3e-4
+    epsilon_high: 4e-4
+    clip_mode: "gspo-one-side"
+    weight: "gspo_importance_sampling"
+  advantage_fn_args:
+    std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-TwoSide-IS:**
@@ -122,6 +152,8 @@ algorithm:
     regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-Ring-NoIS:**
@@ -141,6 +173,8 @@ algorithm:
     regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 ### REP family
@@ -159,6 +193,8 @@ algorithm:
     regularizer_coef: 0.1
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 
@@ -174,6 +210,8 @@ algorithm:
     regularizer_coef: 0.1
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 ### RED family
@@ -191,6 +229,8 @@ algorithm:
   advantage_fn_args:
     std_normalize: false
     drop: "balance"
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 
@@ -206,6 +246,8 @@ algorithm:
     temp: 1.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 ## Citation
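
The two new REC-GSPO configs swap the token-level probability ratio for GSPO's length-normalized sequence-level ratio, which is the geometric mean of the token ratios and therefore concentrates very close to 1; that is why their clip range (3e-4 / 4e-4) is roughly three orders of magnitude tighter than the token-level 0.2. A minimal standalone sketch of that ratio (the function name and argument layout here are illustrative, not the repo's API):

```python
import torch

def gspo_seq_ratio(logprob: torch.Tensor, old_logprob: torch.Tensor,
                   action_mask: torch.Tensor) -> torch.Tensor:
    """Length-normalized sequence-level ratio used by clip_mode "gspo-one-side".

    All inputs are [batch_size, seq_len]; the result is [batch_size, 1] so it
    broadcasts over token positions.
    """
    # Masked mean of per-token log-ratios == log of the geometric mean
    # of the token-level probability ratios.
    diff = (logprob - old_logprob) * action_mask
    mean_log_ratio = diff.sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)
    return torch.exp(mean_log_ratio).unsqueeze(-1)
```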

examples/rec_gsm8k/gsm8k.yaml (+3 -4)

@@ -16,16 +16,15 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "none"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 cluster:
   node_num: 1
   gpu_per_node: 8
 buffer:
-  total_steps: 100
+  total_steps: 160
   batch_size: 96
   explorer_input:
     taskset:
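
kl_loss_fn_args appears here because this commit also changes REC's default kl_loss_fn from "none" to "k2" (see trinity/algorithm/algorithm.py below); setting kl_coef: 0.0 keeps that new term out of the gradient. A hedged sketch of the usual wiring, assuming the trainer combines the terms linearly (names here are illustrative, not Trinity's internals):

```python
import torch

def total_actor_loss(pg_loss: torch.Tensor, kl_loss: torch.Tensor,
                     kl_coef: float = 0.0) -> torch.Tensor:
    """Typical assembly of the actor loss with a KL regularization term.

    With kl_coef=0.0 the KL estimate contributes nothing to the gradient,
    though its value can still be logged as a training metric.
    """
    return pg_loss + kl_coef * kl_loss
```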

trinity/algorithm/algorithm.py (+2 -2)

@@ -434,8 +434,8 @@ def default_config(cls) -> Dict:
             "policy_loss_fn": "rec",
             "advantage_fn": "rec",
             "kl_penalty_fn": "none",
-            "kl_loss_fn": "none",
-            "entropy_loss_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "default",
         }
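
REC's defaults now enable a KL loss ("k2") and the default entropy loss rather than disabling both, which is why the example configs above set kl_coef explicitly. Assuming "k2" follows the common k1/k2/k3 naming for Monte Carlo KL estimators, it computes half the squared log-ratio per token; a sketch (the function name and masking convention are mine):

```python
import torch

def k2_kl_estimate(logprob: torch.Tensor, ref_logprob: torch.Tensor,
                   action_mask: torch.Tensor) -> torch.Tensor:
    """k2 estimator of KL(pi || pi_ref): 0.5 * (log pi - log pi_ref)^2.

    Low-variance and always non-negative; averaged over valid action tokens.
    """
    k2 = 0.5 * (logprob - ref_logprob) ** 2
    return (k2 * action_mask).sum() / action_mask.sum().clamp(min=1)
```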

trinity/algorithm/policy_loss_fn/gspo_policy_loss.py (+9 -0)

@@ -30,6 +30,15 @@ def __init__(
         if _clip_range_high is None:
             raise ValueError("Either clip_range or clip_range_high must be specified.")
         self.clip_range_high = _clip_range_high
+
+        if loss_agg_mode != "seq-mean-token-mean":
+            from trinity.utils.log import get_logger
+
+            logger = get_logger(__name__)
+            logger.warning(
+                f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
+            )
+            # loss_agg_mode = "seq-mean-token-mean"
         self.loss_agg_mode = loss_agg_mode
 
     def __call__(  # type: ignore
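
The warning exists because GSPO defines its loss with sequence-mean aggregation: token losses are averaged within each sequence first, then across sequences, so short and long responses carry equal weight; a flat token mean would let long responses dominate the batch. A minimal sketch of "seq-mean-token-mean" (standalone, not the repo's aggregate_loss signature):

```python
import torch

def seq_mean_token_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average token losses within each sequence, then across the batch.

    values, mask: [batch_size, seq_len]. Every sequence contributes equally
    to the final scalar regardless of its length.
    """
    per_seq = (values * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return per_seq.mean()
```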

trinity/algorithm/policy_loss_fn/rec_policy_loss.py (+26 -4)

@@ -6,7 +6,7 @@
 import torch
 
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean
 
 
 class RECPolicyLossFn(PolicyLossFn):
@@ -41,13 +41,15 @@ def __init__(
         assert self.clip_mode in [
             "none",
             "one-side",
+            "gspo-one-side",
             "two-side",
             "ring",
         ], f"Invalid clip_mode: {self.clip_mode}"
         self.weight = weight
         assert self.weight in [
             "none",
             "importance_sampling",
+            "gspo_importance_sampling",
             "advantage",
         ], f"Invalid weight: {self.weight}"
 
@@ -71,8 +73,8 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         """Calculate REC loss."""
-        # token-wise
-        ratio = torch.exp(logprob - old_logprob).detach()
+
+        ratio = torch.exp(logprob - old_logprob).detach()  # token-wise prob ratio
 
         # clipping
         if self.clip_mode == "two-side":
@@ -81,6 +83,16 @@ def __call__(  # type: ignore
             is_in_range = (ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
                 advantages <= 0
             ) * (ratio >= (1 - self.epsilon_low))
+        elif self.clip_mode == "gspo-one-side":
+            mean_log_prob_diff = masked_mean(
+                logprob - old_logprob, action_mask, axis=-1
+            ).detach()  # [batch_size]
+            normalized_seq_ratio = torch.exp(mean_log_prob_diff).unsqueeze(-1)  # [batch_size, 1]
+            is_in_range = (normalized_seq_ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
+                normalized_seq_ratio >= (1 - self.epsilon_low)
+            ) * (
+                advantages <= 0
+            )  # [batch_size, seq_len]
         elif self.clip_mode == "ring":
             is_in_range = (
                 (ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high))
@@ -93,6 +105,8 @@ def __call__(  # type: ignore
 
         if self.weight == "importance_sampling":
             advantages = advantages * ratio  # importance sampling
+        elif self.weight == "gspo_importance_sampling":
+            advantages = advantages * normalized_seq_ratio
         elif self.weight == "advantage":
             weight = torch.exp(advantages / self.temp)
             advantages = advantages * weight  # advantage weighting (unnormalized version)
@@ -107,7 +121,15 @@ def __call__(  # type: ignore
             regularizer_losses = self.regularizer_coef * (logprob - old_logprob).square()
             pg_losses = pg_losses + regularizer_losses
 
-        pg_loss = masked_mean(pg_losses, action_mask)
+        if self.clip_mode == "gspo-one-side":
+            # [EXPERIMENTAL] specialized for gspo-style rec variant for now
+            pg_loss = aggregate_loss(
+                values=pg_losses,
+                mask=action_mask,
+                loss_agg_mode="seq-mean-token-mean",
+            )
+        else:
+            pg_loss = masked_mean(pg_losses, action_mask)
 
         pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
         metrics = {
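
A deterministic toy run of the new branch's clipping logic (a standalone re-implementation for illustration, using & / | in place of the * / + boolean arithmetic above, which is logically equivalent here):

```python
import torch

# Two sequences of 4 tokens; the new policy is uniformly slightly
# more confident, so every token ratio is exp(0.01) ~ 1.01005.
logprob = torch.tensor([[-1.0, -1.1, -0.9, -1.0], [-2.0, -1.9, -2.1, -2.0]])
old_logprob = logprob - 0.01
action_mask = torch.ones_like(logprob)
advantages = torch.tensor([[1.0], [-1.0]]).expand(-1, 4)  # per-sequence sign
eps_low, eps_high = 3e-4, 4e-4

# Sequence-level ratio, as in the gspo-one-side branch.
mean_diff = ((logprob - old_logprob) * action_mask).sum(-1) / action_mask.sum(-1)
seq_ratio = torch.exp(mean_diff).unsqueeze(-1)  # [2, 1], ~1.01005

# One-sided check: positive advantages clip the upper side only,
# negative advantages clip the lower side only.
in_range = ((seq_ratio <= 1 + eps_high) & (advantages >= 0)) | (
    (seq_ratio >= 1 - eps_low) & (advantages <= 0)
)
print(in_range)
# Row 0 (A > 0): ratio ~1.01 exceeds 1 + 4e-4  -> all tokens clipped (False).
# Row 1 (A < 0): ratio ~1.01 is above 1 - 3e-4 -> all tokens kept (True).
```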
