
Commit 5c62657

yfw and SahilJain314 authored
feat: Importance sampling trick (#174)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com> Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com>
1 parent deaece6 commit 5c62657

File tree

5 files changed: +295 -9 lines changed

docs/adding-new-models.md

Lines changed: 4 additions & 4 deletions
@@ -8,7 +8,7 @@ In on-policy RL, we sample tokens (actions) from the latest version of the polic
 
 As an example, we would see errors in naive KL estimation:
 
-$$\text{KL} = E_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$
+$$\text{KL} = E_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$
 
 When summed/integrated, replacing the $x \sim \pi$ with $x \sim \pi_{\text{wrong}}$ leads to an error of:
 
@@ -17,12 +17,12 @@ $$\sum_{x} \left( \pi(x) - \pi_{\text{ref}}(x) \right) \left( \pi_{\text{wrong}}
 So, to verify correctness, we calculate
 
 $$
-\frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-sampling-fwk}_i\right\|\right)
+\frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-inference-fwk}_i\right\|\right)
 $$
 
-where samples are drawn as $x \sim \pi_{\text{sampling-framework}}$
+where samples are drawn as $x \sim \pi_{\text{inference-framework}}$
 
-As a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the sampling framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{sampling-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient.
+As a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the inference framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{inference-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient.
 
 ## Understanding Discrepancies Between Backends
 
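For concreteness, here is a minimal PyTorch sketch of the per-token check described in that doc hunk. The helper name and tensor layout are assumptions for illustration; the commit's actual test code is not shown in this diff.

```python
import torch

def multiplicative_logprob_error(
    logprobs_train: torch.Tensor,      # per-token logprobs from the training framework
    logprobs_inference: torch.Tensor,  # per-token logprobs from the inference framework
    mask: torch.Tensor,                # 1.0 for sampled tokens, 0.0 for prompt/padding
) -> float:
    """Mean of exp(|logprob difference|) over the sampled tokens.

    A value of 1.0 means the two frameworks agree exactly; e.g. 1.05 corresponds
    to roughly a 5% multiplicative probability error per sampled token on average.
    """
    per_token = torch.exp((logprobs_train - logprobs_inference).abs())
    return ((per_token * mask).sum() / mask.sum()).item()
```

Samples here are assumed to be drawn from the inference framework, matching the metric above.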

docs/guides/grpo.md

Lines changed: 59 additions & 1 deletion
@@ -16,12 +16,13 @@ If not specified, `config` will default to [examples/configs/grpo.yaml](../../ex
 
 ## Now, for the details:
 
-In this guide, we'll walk through we handle
+In this guide, we'll walk through how we handle
 
 * Data
 * Model training
 * Fast generation
 * Overall Resource Flow
+* Loss
 
 ### Data
 

@@ -108,3 +109,60 @@ This Policy object holds a [RayWorkerGroup](../../nemo_reinforcer/distributed/wo
 We support vLLM through the [VllmGeneration](../../nemo_reinforcer/models/generation/vllm.py) class right now.
 
 The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop.
+
+### Loss
+We use the [ClippedPGLossFn](../../nemo_reinforcer/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally,
+
+$$
+L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref})
+$$
+
+where:
+
+- $\pi_\theta$ is the policy model we are currently optimizing
+- $\pi_{\theta_{\text{old}}}$ is the previous policy model (from the beginning of this step)
+- $A_t$ is the advantage estimate
+- $\varepsilon$ is a clipping hyperparameter
+- $\beta$ is the KL penalty coefficient
+- $\pi_{\text{ref}}$ is the reference policy
+
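As an illustration of the clipped first term above, here is a minimal PyTorch sketch. The standalone helper is hypothetical; in the repository this logic lives inside `ClippedPGLossFn`, which also negates the objective to form a loss and applies a token mask, and the config exposes separate `ratio_eps_min` and `ratio_eps_max` values rather than the single `eps` used here for brevity.

```python
import torch

def clipped_pg_term(
    curr_logprobs: torch.Tensor,  # log pi_theta(x) for each sampled token
    prev_logprobs: torch.Tensor,  # log pi_theta_old(x) for each sampled token
    advantages: torch.Tensor,     # A_t for each token
    eps: float = 0.2,             # clipping hyperparameter epsilon
) -> torch.Tensor:
    """Per-token min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A), to be maximized."""
    ratios = torch.exp(curr_logprobs - prev_logprobs)           # pi_theta / pi_theta_old
    ratios_clamped = torch.clamp(ratios, 1.0 - eps, 1.0 + eps)  # clip(ratio, 1 - eps, 1 + eps)
    return torch.min(ratios * advantages, ratios_clamped * advantages)
```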
+#### Improvements to the GRPO loss formulation for stability and accuracy
+
+#### On-Policy KL Approximation
+
+In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive.
+
+$$
+D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) \approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big]
+$$
+
+Note that the loss function above samples from $\pi_{\theta_{\text{old}}}$ instead of $\pi_\theta$, meaning that the KL approximation is off-policy if we use samples from $\pi_{\theta_{\text{old}}}$. This is the default formulation used in the [original GRPO paper](https://arxiv.org/abs/2402.03300). In order to use an _on-policy_ KL approximation while sampling from $\pi_{\theta_{\text{old}}}$, we can incorporate importance weights:
+
+$$
+\begin{align*}
+D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\
+&= \sum_x \pi_{\theta}(x) \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\
+&= \sum_x \pi_{\theta_{\text{old}}}(x) \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\
+&= E_{x \sim \pi_{\theta_\text{old}}} \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big]
+\end{align*}
+$$
+
+To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
+
+
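A minimal PyTorch sketch of this importance-weighted k3 estimate follows. The tensor names mirror the variables in the loss_functions.py diff below but the helper itself is an assumption; the actual implementation calls `calculate_kl_penalty_joschu2020`, scales by the KL penalty coefficient, and additionally guards the weights with `torch.nan_to_num`.

```python
import torch

def on_policy_kl_estimate(
    curr_logprobs: torch.Tensor,        # log pi_theta(x) from the training framework
    ref_logprobs: torch.Tensor,         # log pi_ref(x) for the same tokens
    generation_logprobs: torch.Tensor,  # log-probs of the sampled tokens under the policy that generated them
) -> torch.Tensor:
    """Per-token KL(pi_theta || pi_ref) estimate, importance-weighted onto pi_theta."""
    # k3 estimator: r - log r - 1, with r = pi_ref(x) / pi_theta(x)
    log_r = ref_logprobs - curr_logprobs
    k3 = torch.exp(log_r) - log_r - 1.0
    # Importance weights pi_theta(x) / pi_theta_old(x); detached so no gradient flows through them
    weights = torch.exp(curr_logprobs - generation_logprobs).detach()
    return weights * k3
```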
+#### Importance Sampling Correction
+The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is instantiated in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible for the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies (from numerics, precision differences, bugs, etc.), leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ in the first term of the loss function.
+
+Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ denote the first term of the loss function. Then,
+
+$$
+\begin{align*}
+E_{x \sim \pi_\text{training}} f_\theta(x) &= \sum_x \pi_\text{training}(x) f_\theta(x) \\
+&= \sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\
+&= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x)
+\end{align*}
+$$
+
+By multiplying the first term of the loss function by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$.
+
+To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
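To make the importance weights concrete, here is a small worked example with made-up log-probabilities for three sampled tokens. The weight is exactly 1 wherever the two backends agree, so the correction is a no-op for perfectly matched frameworks.

```python
import torch

# log pi_theta_old(x) for the same three sampled tokens, from the two frameworks
prev_logprobs = torch.tensor([-1.20, -0.70, -2.30])        # training framework
generation_logprobs = torch.tensor([-1.25, -0.70, -2.10])  # inference framework

# Importance weights pi_training(x) / pi_inference(x)
weights = torch.exp(prev_logprobs - generation_logprobs)
print(weights)  # tensor([1.0513, 1.0000, 0.8187])
```

These weights multiply the per-token clipped term inside the masked mean, as shown in the loss_functions.py diff below.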

examples/configs/grpo_math_1B.yaml

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_eps_min: 0.2
   ratio_eps_max: 0.2
+  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
+  use_on_policy_kl_approximation: false
+  use_importance_sampling_correction: false
 
 checkpointing:
   enabled: true
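For reference, a sketch of how a `loss_fn` block like the one above could be wired into the loss class. The import path and the dict literal are assumptions based on the file paths and TypedDict keys shown in this commit; other required keys of `ClippedPGLossConfig` may not appear in this diff.

```python
# Hypothetical wiring of the YAML loss_fn values into ClippedPGLossFn
from nemo_reinforcer.algorithms.loss_functions import ClippedPGLossConfig, ClippedPGLossFn

loss_cfg: ClippedPGLossConfig = {
    "reference_policy_kl_penalty": 0.01,
    "ratio_eps_min": 0.2,
    "ratio_eps_max": 0.2,
    # (default off) loss formulation improvements
    "use_on_policy_kl_approximation": False,
    "use_importance_sampling_correction": False,
}
loss_fn = ClippedPGLossFn(loss_cfg)
```

Note that `__init__` reads the two new keys with `cfg[...]` rather than `cfg.get(...)`, so they must be present in the config once this commit is applied.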

nemo_reinforcer/algorithms/loss_functions.py

Lines changed: 34 additions & 4 deletions
@@ -31,6 +31,8 @@ class ClippedPGLossConfig(TypedDict):
     reference_policy_kl_penalty: float
     ratio_eps_min: float
     ratio_eps_max: float
+    use_on_policy_kl_approximation: bool
+    use_importance_sampling_correction: bool
 
 
 class ClippedPGLossDataDict(TypedDict):
@@ -80,6 +82,10 @@ def __init__(self, cfg: ClippedPGLossConfig):
         self.ratio_eps_max = cfg["ratio_eps_max"]
         self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
         self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
+        self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
+        self.use_importance_sampling_correction = cfg[
+            "use_importance_sampling_correction"
+        ]
 
     def __call__(
         self,
@@ -122,9 +128,23 @@ def __call__(
 
         # Calculate KL regularization.
         if self.reference_policy_kl_penalty != 0:
-            kl = self.reference_policy_kl_penalty * calculate_kl_penalty_joschu2020(
-                logprobs_policy=curr_logprobs,
-                logprobs_reference=reference_policy_logprobs,
+            if self.use_on_policy_kl_approximation:
+                # See: docs/guides/grpo.md#on-policy-kl-approximation
+                kl_importance_weights = torch.exp(
+                    curr_logprobs - generation_logprobs
+                ).detach()
+                kl_importance_weights = torch.nan_to_num(
+                    kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
+                )
+            else:
+                kl_importance_weights = torch.ones_like(curr_logprobs)
+            kl = (
+                kl_importance_weights
+                * self.reference_policy_kl_penalty
+                * calculate_kl_penalty_joschu2020(
+                    logprobs_policy=curr_logprobs,
+                    logprobs_reference=reference_policy_logprobs,
+                )
             )
             kl = masked_mean(kl, mask)
         else:
@@ -143,7 +163,17 @@ def __call__(
         loss1 = -advantages * ratios
         loss2 = -advantages * ratios_clamped
 
-        actor_loss = masked_mean(torch.max(loss1, loss2), mask)
+        if self.use_importance_sampling_correction:
+            # See: docs/guides/grpo.md#importance-sampling-correction
+            actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs)
+            actor_importance_weights = torch.nan_to_num(
+                actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
+            )
+        else:
+            actor_importance_weights = torch.ones_like(prev_logprobs)
+        actor_loss = masked_mean(
+            actor_importance_weights * torch.max(loss1, loss2), mask
+        )
         loss = actor_loss + kl
         with torch.no_grad():
             probs_ratio = masked_mean(ratios.detach(), mask).item()
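One plausible reason for the `torch.nan_to_num` guards above (an illustrative assumption, since this diff does not say when non-finite log-probs occur): if one framework reports a `-inf` log-probability for a token, the raw importance weight becomes `inf` or `nan`, and the guard zeroes that token out of the loss.

```python
import torch

prev_logprobs = torch.tensor([-0.5, -1.0])                 # training framework
generation_logprobs = torch.tensor([-0.5, float("-inf")])  # inference framework reports -inf

raw = torch.exp(prev_logprobs - generation_logprobs)
print(raw)  # tensor([1., inf])

safe = torch.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0)
print(safe)  # tensor([1., 0.])
```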
