Skip to content

Commit 406d406

Browse files
authored
feat(grpo_trainer.py): Variational Sequence-Level Soft Policy Optimization (VESPO) (#5199)
1 parent d0ac7ef commit 406d406

File tree

4 files changed

+166
-11
lines changed

4 files changed

+166
-11
lines changed

docs/source/paper_index.md

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,24 @@ trainer = GRPOTrainer(
165165
)
166166
```
167167

168+
### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning
169+
170+
**📜 Paper**: https://huggingface.co/papers/2505.07291
171+
172+
INTELLECT-2 is the first globally distributed reinforcement learning training run of a 32-billion-parameter language model using fully asynchronous RL across a dynamic, heterogeneous swarm of permissionless compute contributors. The authors propose modifications to the standard GRPO training recipe, including two-sided GRPO clipping for increased training stability. To reproduce the paper's setting, use this configuration:
173+
174+
```python
175+
from trl import GRPOConfig
176+
177+
training_args = GRPOConfig(
178+
delta=4, # δ in section 4.1 of the paper
179+
epsilon=0.2, # ε in section 4.1 of the paper
180+
beta=0.001, # KL divergence coefficient in section 4.1 of the paper
181+
num_generations=16, # responses per prompt in section 4.1 of the paper
182+
learning_rate=3e-7, # section 4.1 of the paper
183+
)
184+
```
185+
168186
### Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning
169187

170188
**📜 Paper**: https://huggingface.co/papers/2506.01939
@@ -573,21 +591,30 @@ training_args = GRPOConfig(
573591
)
574592
```
575593

576-
### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning
594+
### VESPO: Variational Sequence-Level Soft Policy Optimization for Stable Off-Policy LLM Training
577595

578-
**📜 Paper**: https://huggingface.co/papers/2505.07291
596+
**📜 Paper**: https://huggingface.co/papers/2602.10693
579597

580-
INTELLECT-2 is the first globally distributed reinforcement learning training run of a 32 billion parameter language model using fully asynchronous RL across a dynamic, heterogeneous swarm of permissionless compute contributors. The authors propose modifications to the standard GRPO training recipe, including two-sided GRPO clipping for increased training stability. To reproduce the paper's setting, use this configuration:
598+
VESPO addresses training instability in off-policy RL caused by policy staleness, asynchronous updates, and train-inference mismatches. Rather than relying on heuristic token-level clipping (GRPO) or sequence-length normalization (GSPO), VESPO derives a principled reshaping kernel from a variational framework. In practice, this yields a smooth, asymmetric Gamma weighting function that gracefully suppresses extreme sequence-level importance weights without introducing length bias.
599+
600+
$$
\mathcal{L}_{\text{VESPO}}(\theta) = - \mathbb{E}_{\tau \sim \mu} \left[ \underbrace{W(\tau)^{k} \cdot \exp\left(\lambda (1 - W(\tau))\right)}_{\phi(W)\ \text{detached}} \cdot \mathcal{A}(\tau) \cdot \log \pi_\theta(\tau) \right]
$$
604+
605+
where \\( W(\tau) = \frac{\pi_\theta(\tau)}{\mu(\tau)} \\) is the sequence-level importance ratio, and \\( \phi(W) \\) is detached from the computation graph to serve as a gradient-scaling coefficient.
581606

582607
```python
583608
from trl import GRPOConfig
584609

585610
training_args = GRPOConfig(
586-
delta=4, # δ in section 4.1 of the paper
587-
epsilon=0.2, # ε in section 4.1 of the paper
588-
beta=0.001, # KL divergence coefficient in section 4.1 of the paper
589-
num_generations=16, # responses per prompt in section 4.1 of the paper
590-
learning_rate=3e-7, # section 4.1 of the paper
611+
loss_type="vespo",
612+
use_vllm=True, # or False if not using any token-level `vllm_importance_sampling_correction` methods
613+
vllm_importance_sampling_mode="token_truncate", # default correction mode for VESPO, `token_mask` also supported
614+
vespo_k_pos=2.0, # power exponent (c1 in paper Section 3.4) for positive advantages
615+
vespo_lambda_pos=3.0, # decay factor (c2 in paper Section 3.4) for positive advantages
616+
vespo_k_neg=3.0, # power exponent (c1 in paper Section 3.4) for negative advantages
617+
vespo_lambda_neg=2.0, # decay factor (c2 in paper Section 3.4) for negative advantages
591618
)
592619
```
593620

tests/test_grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def test_training(self, config_name):
278278
new_param = trainer.model.get_parameter(n)
279279
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
280280

281-
@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"])
281+
@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo"])
282282
def test_training_loss_types(self, loss_type):
283283
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
284284

trl/trainer/grpo_config.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ class GRPOConfig(_BaseConfig):
186186
sapo_temperature_pos (`float`, *optional*, defaults to `1.0`):
187187
Temperature for tokens with positive advantage scores used in the `sapo` loss function. This parameter is
188188
introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347).
189+
vespo_k_pos (`float`, *optional*, defaults to `2.0`):
190+
k parameter for positive advantages; it is the power exponent in the VESPO loss. Controls how aggressively
191+
we down-weight samples with low importance weights (when the importance sampling ratio < 1).
192+
vespo_lambda_pos (`float`, *optional*, defaults to `3.0`):
193+
lambda parameter for positive advantages; it is the decay factor in the VESPO loss. Controls how
194+
aggressively we down-weight samples with high importance weights (when the importance sampling ratio > 1).
195+
vespo_k_neg (`float`, *optional*, defaults to `3.0`):
196+
k parameter for negative advantages; it is the power exponent in the VESPO loss. Controls how aggressively
197+
we down-weight samples with low importance weights (when the importance sampling ratio < 1).
198+
vespo_lambda_neg (`float`, *optional*, defaults to `2.0`):
199+
lambda parameter for negative advantages; it is the exponential decay factor in the VESPO loss. Controls
200+
how aggressively we down-weight samples with high importance weights (when the importance sampling ratio >
201+
1).
189202
importance_sampling_level (`str`, *optional*, defaults to `"token"`):
190203
Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
191204
keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
@@ -243,6 +256,9 @@ class GRPOConfig(_BaseConfig):
243256
sequence's loss by its length. This is a modification of GSPO and requires
244257
`importance_sampling_level="sequence"`. Introduced in the [LUSPO
245258
paper](https://huggingface.co/papers/2602.05261).
259+
- `"vespo"`: Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth,
260+
asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in
261+
the [VESPO paper](https://huggingface.co/papers/2602.10693).
246262
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
247263
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
248264
incorrectly penalized and introducing noise during training. According to the
@@ -619,6 +635,36 @@ class GRPOConfig(_BaseConfig):
619635
"paper](https://huggingface.co/papers/2511.20347)."
620636
},
621637
)
638+
vespo_k_pos: float = field(
639+
default=2.0,
640+
metadata={
641+
"help": "k parameter for positive advantages, it is the power exponent in the VESPO loss. Controls how "
642+
"aggressively we down-weight samples with low importance weights (when the importance sampling ratio < 1)."
643+
},
644+
)
645+
vespo_lambda_pos: float = field(
646+
default=3.0,
647+
metadata={
648+
"help": "lambda parameter for positive advantages, it is the decay factor in the VESPO loss. Controls "
649+
"how aggressively we down-weight samples with high importance weights (when the importance sampling ratio "
650+
"> 1)."
651+
},
652+
)
653+
vespo_k_neg: float = field(
654+
default=3.0,
655+
metadata={
656+
"help": "k parameter for negative advantages, it is the power exponent in the VESPO loss. Controls how "
657+
"aggressively we down-weight samples with low importance weights (when the importance sampling ratio < 1)."
658+
},
659+
)
660+
vespo_lambda_neg: float = field(
661+
default=2.0,
662+
metadata={
663+
"help": "lambda parameter for negative advantages, it is the exponential decay factor in the VESPO loss. "
664+
"Controls how aggressively we down-weight samples with high importance weights (when the importance "
665+
"sampling ratio > 1)."
666+
},
667+
)
622668
importance_sampling_level: str = field(
623669
default="token",
624670
metadata={
@@ -690,6 +736,9 @@ class GRPOConfig(_BaseConfig):
690736
"sequence's loss by its length. This is a modification of GSPO and requires "
691737
"`importance_sampling_level='sequence'`. Introduced in the [LUSPO "
692738
"paper](https://huggingface.co/papers/2602.05261)."
739+
"'vespo': Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, "
740+
"asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in "
741+
"the [VESPO paper](https://huggingface.co/papers/2602.10693)."
693742
},
694743
)
695744
mask_truncated_completions: bool = field(

trl/trainer/grpo_trainer.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
import importlib.resources as pkg_resources
1919
import inspect
20+
import math
2021
import os
2122
import sys
2223
import textwrap
@@ -579,6 +580,19 @@ def __init__(
579580
"paper's setup."
580581
)
581582

583+
if args.loss_type == "vespo" and args.importance_sampling_level != "token":
584+
logger.warning(
585+
"VESPO computes sequence-level importance weights internally. `importance_sampling_level` should be "
586+
"set to `'token'` (the default)."
587+
)
588+
589+
if self.loss_type == "vespo" and self.use_vllm and self.vllm_importance_sampling_correction:
590+
if self.vllm_importance_sampling_mode not in ["token_truncate", "token_mask"]:
591+
raise ValueError(
592+
f"VESPO loss requires `vllm_importance_sampling_mode` to be either 'token_truncate' or "
593+
f"'token_mask'. Got: {self.vllm_importance_sampling_mode}."
594+
)
595+
582596
# Multi-step
583597
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
584598
self.epsilon_low = args.epsilon
@@ -2099,6 +2113,56 @@ def get_off_policy_mask(
20992113
is_low_kl = avg_seq_kl <= off_policy_threshold
21002114
return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1)
21012115

2116+
@staticmethod
2117+
@torch.no_grad()
2118+
def get_gamma_weights(
2119+
advantages: torch.Tensor,
2120+
log_ratio_per_token: torch.Tensor,
2121+
mask: torch.Tensor,
2122+
importance_sampling_ratio: torch.Tensor | None, # (B, T)
2123+
k_pos: float = 2.0,
2124+
lambda_pos: float = 3.0,
2125+
k_neg: float = 3.0,
2126+
lambda_neg: float = 2.0,
2127+
) -> torch.Tensor:
2128+
"""
2129+
Computes the Gamma weights for the VESPO loss. For reference:
2130+
φ(w) = e^λ × w^k × e^{-λw} is the gamma weighting (normalized so φ(1)=1)
2131+
with w = sequence-level importance sampling ratio
2132+
note: we will compute φ(w) in log space
2133+
2134+
φ(w) is detached via @torch.no_grad(), only acts as gradient scaling coefficient
2135+
2136+
VESPO loss = -φ(w) × A × log_prob, gradient naturally gives φ(w) × A × ∇log π
2137+
"""
2138+
# reducing clamp range directly to log(1e-8) ~ -18.42, to avoid recomputing log_w=log(w.clamp(min=1e-8)) later
2139+
# This is solely for matching truthfully the original implementation, otherwise keeping -20 could be fine.
2140+
lower_clamp = math.log(1e-8)
2141+
2142+
# Sequence-level log ratio Σ log(π_θ/π_old) (not a mean like for `log_importance_weights`)
2143+
log_ratio_clamped = torch.clamp(log_ratio_per_token, -20.0, 20.0)
2144+
seq_log_ratio = torch.sum(log_ratio_clamped * mask, dim=-1, keepdim=True) # (B, 1)
2145+
2146+
# Apply token-level TIS or MIS correction (in log space)
2147+
if importance_sampling_ratio is not None:
2148+
log_is_ratio = torch.clamp(torch.log(importance_sampling_ratio), lower_clamp, 20.0)
2149+
# log(w) = log(π_θ/π_old) + log(π_old/π_sampler)
2150+
seq_log_ratio += torch.sum(log_is_ratio, dim=-1, keepdim=True)
2151+
2152+
log_w_seq = torch.clamp(seq_log_ratio, lower_clamp, 20.0)
2153+
w_seq = torch.exp(log_w_seq)
2154+
2155+
# compute k and lambda based on advantage sign
2156+
is_nonneg_adv = advantages >= 0
2157+
k_seq = torch.where(is_nonneg_adv, k_pos, k_neg)
2158+
lambda_seq = torch.where(is_nonneg_adv, lambda_pos, lambda_neg).clamp(min=1e-4)
2159+
2160+
# log(φ(w)) = λ + k × log(w) - λ × w
2161+
log_phi = lambda_seq + k_seq * log_w_seq - lambda_seq * w_seq
2162+
phi_seq = torch.exp(log_phi).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
2163+
2164+
return phi_seq # (B, 1)
2165+
21022166
def _compute_loss(self, model, inputs):
21032167
# Compute the per-token log probabilities for the model
21042168
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
@@ -2200,6 +2264,18 @@ def _compute_loss(self, model, inputs):
22002264
temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg)
22012265
soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures
22022266
per_token_loss = -soft_coef_1 * advantages
2267+
elif self.loss_type == "vespo":
2268+
phi_seq = self.get_gamma_weights(
2269+
advantages=advantages,
2270+
log_ratio_per_token=log_ratio,
2271+
mask=mask,
2272+
importance_sampling_ratio=inputs.get("importance_sampling_ratio"),
2273+
k_pos=self.args.vespo_k_pos,
2274+
lambda_pos=self.args.vespo_lambda_pos,
2275+
k_neg=self.args.vespo_k_neg,
2276+
lambda_neg=self.args.vespo_lambda_neg,
2277+
)
2278+
per_token_loss = -phi_seq * advantages * per_token_logps
22032279
else:
22042280
raise ValueError(f"Unknown loss type: {self.loss_type}")
22052281

@@ -2209,7 +2285,7 @@ def _compute_loss(self, model, inputs):
22092285
if entropy_mask is not None:
22102286
per_token_loss = per_token_loss * entropy_mask
22112287

2212-
if self.use_vllm and self.vllm_importance_sampling_correction:
2288+
if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
22132289
per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
22142290

22152291
if self.beta != 0.0:
@@ -2228,7 +2304,7 @@ def _compute_loss(self, model, inputs):
22282304
loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
22292305
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
22302306
loss = loss / normalizer
2231-
elif self.loss_type in ["cispo", "dapo"]:
2307+
elif self.loss_type in ["cispo", "dapo", "vespo"]:
22322308
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
22332309
loss = (per_token_loss * mask).sum() / normalizer
22342310
elif self.loss_type == "luspo":
@@ -2278,6 +2354,9 @@ def masked_batch_mean(x):
22782354
cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
22792355
gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
22802356
self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
2357+
elif self.loss_type == "vespo":
2358+
gathered_phi_seq = self.accelerator.gather(phi_seq)
2359+
self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item())
22812360

22822361
return loss
22832362

0 commit comments

Comments
 (0)