Commit 46e2cd4

Refactoring training/inference importance sampling with sequence/geometric level (#429)
Co-authored-by: Jiajun Li <[email protected]>
1 parent 4c54df7 commit 46e2cd4

5 files changed: +557, -10 lines changed
Lines changed: 328 additions & 0 deletions
@@ -0,0 +1,328 @@
from typing import Any, Dict, Optional, Tuple

import torch

from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp


def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
    result = (x * loss_mask).sum()
    return result.expand_as(x) if expand else result


def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
    result = masked_sum(x, loss_mask) / torch.clamp_min(loss_mask.sum(), 1)
    return result.expand_as(x) if expand else result


def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None:
    """
    Every metrics-dict value is a list of 1D tensors (i.e., list[torch.Tensor]) with shapes exactly matching log_probs.

    All metrics will be aggregated and averaged by `sum_of_sample_mean` and divided by the DP size automatically.
    - If calculate_per_token_loss=False (default), the final result is first averaged within each sequence,
      then across all the sequences in the global batch.
    - If calculate_per_token_loss=True, the final result is the mean over all the tokens in the global batch.

    There is no need to handle loss_mask explicitly; sum_of_sample_mean automatically ignores statistics where loss_mask = 0.

    e.g.
    For a token-level metric:
        value = [
            [0.1, 0.2],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.6]
        ]
        When calculate_per_token_loss = False (default):
            result = ((0.1 + 0.2) / 2 + (0.1 + 0.2 + 0.3 + 0.4 + 0.5) / 5 + 0.6 / 1) / 3 = (0.15 + 0.3 + 0.6) / 3 = 1.05 / 3 = 0.35
        When calculate_per_token_loss = True:
            result = (0.1 + 0.2 + 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6) / 8 = 2.4 / 8 = 0.3
    For a sequence-level metric:
        original sequence lengths = [2, 5, 2]
        We should expand the metric to the length of each sequence:
        value = [
            [2, 2],
            [5, 5, 5, 5, 5],
            [1, 1]
        ]
        When calculate_per_token_loss = False (default):
            result = ((2 + 2) / 2 + (5 + 5 + 5 + 5 + 5) / 5 + (1 + 1) / 2) / 3 = (2 + 5 + 1) / 3 = 8 / 3 ≈ 2.667
    Note that for sequence-level metrics, computing a token-level loss is invalid; thus, calculate_per_token_loss should always be False.
    """
    if key not in metrics:
        metrics[key] = []
    metrics[key].append(value.clone().detach())


def calculate_veto_mask(
    log_ratio: torch.Tensor,
    loss_mask: torch.Tensor,
    veto_threshold: Optional[float],
    metrics: Dict[str, list[torch.Tensor]],
) -> torch.Tensor:
    if veto_threshold is None:
        return torch.ones_like(log_ratio)
    log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio.device))
    # If a sequence has any catastrophic tokens, return 0 for the whole sequence
    catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask.bool()
    has_catastrophic = catastrophic_tokens.any()
    veto_mask = (~has_catastrophic).float().expand_as(log_ratio)

    metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int())
    metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask))
    return veto_mask


def truncate(
    weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, list[torch.Tensor]], upper_bound: float
) -> torch.Tensor:
    assert upper_bound is not None
    metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int())
    return weights.clamp(0, upper_bound) * loss_mask


def clip(
    weights: torch.Tensor,
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
    lower_bound: float,
    upper_bound: float,
) -> torch.Tensor:
    assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
    metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int())
    metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int())
    return weights.clamp(lower_bound, upper_bound) * loss_mask


def mask(
    weights: torch.Tensor,
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
    lower_bound: float,
    upper_bound: float,
) -> torch.Tensor:
    assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
    metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int())
    metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int())
    mask = (weights >= lower_bound) & (weights <= upper_bound)
    return weights * mask * loss_mask


def compute_mis_weights(
    args,
    *,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]:
    """
    Compute the importance sampling (IS) weights and metrics between the inference and training engines.
    Args:
        train_log_probs: List of log probs from the training backend. 1D tensor each. Lengths can differ.
        rollout_log_probs: List of log probs from the inference backend. 1D tensor each.
        loss_masks: List of loss masks. 1D tensor each.
            Note that for single-turn RL, the loss_mask is a [1] * response_length tensor for each sequence.
            For multi-turn RL, the tool responses are marked as 0 in the loss_mask.

    Returns:
        weights: List of importance sampling weights. 1D tensor each.
        metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each.
    """

    level: str = args.mis_level
    metrics: Dict[str, list[torch.Tensor]] = {}

    # Default the lower bound to the reciprocal of the upper bound,
    # matching the documented default for mis_lower_bound.
    if args.mis_lower_bound is None:
        args.mis_lower_bound = 1.0 / args.mis_upper_bound

    # Validate that the input lists have the same length and each sequence has matching shapes
    assert (
        len(train_log_probs) == len(rollout_log_probs) == len(loss_masks)
    ), f"Input lists must have the same number of sequences: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}"

    for i, (train, rollout, loss_mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)):
        assert (
            train.shape == rollout.shape == loss_mask.shape
        ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, loss_mask: {loss_mask.shape}"

    SAFETY_BOUND = 20.0  # Safety bound on the log ratio to avoid exp overflow
    all_weights = []

    # Handle each sequence independently
    for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks):
        loss_mask = loss_mask.float()
        add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics)
        raw_log_ratio_diff = train_log_prob - rollout_log_prob

        # level: the aggregation level for the importance sampling weights.
        if level == "token":
            # Per-token ratio (biased)
            log_ratio_for_metrics = raw_log_ratio_diff
        elif level == "sequence":
            # Product of ratios (unbiased but high variance)
            log_ratio_for_metrics = masked_sum(raw_log_ratio_diff, loss_mask, expand=True)
        elif level == "geometric":
            # Geometric mean of ratios (biased but low variance)
            log_ratio_for_metrics = masked_mean(raw_log_ratio_diff, loss_mask, expand=True)
        else:
            raise ValueError(f"Invalid importance sampling level: {level}")

        log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND)
        weights = torch.exp(log_ratio_safe)
        metrics_append(metrics, "mean_is_weight_before_clip", weights)

        # Mask out catastrophic tokens
        if args.mis_veto_threshold is not None:
            veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics)

        # mode: how to handle importance sampling weights that exceed the thresholds.
        if args.mis_mode == "truncate":
            # Cap the importance sampling weights at the upper threshold
            # https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33
            weights = truncate(weights, loss_mask, metrics, args.mis_upper_bound)
        elif args.mis_mode == "mask":
            # Zero the importance sampling weights outside the [lower, upper] range.
            # https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
            weights = mask(
                weights,
                loss_mask,
                metrics,
                args.mis_lower_bound,
                args.mis_upper_bound,
            )
        elif args.mis_mode == "clip":
            # Clip the importance sampling weights to the [lower, upper] range.
            # Original behavior in slime.
            weights = clip(
                weights,
                loss_mask,
                metrics,
                args.mis_lower_bound,
                args.mis_upper_bound,
            )
        else:
            raise ValueError(f"Unsupported mis_mode: {args.mis_mode}")

        metrics_append(metrics, "ratio_mean_after_mis", weights)
        if args.mis_veto_threshold is not None:
            weights = weights * veto_mask
            metrics_append(metrics, "ratio_mean_after_veto_mask", weights)

        weights = weights.detach()
        all_weights.append(weights)

    return all_weights, metrics


def compute_mis_weights_with_cp(
    args,
    *,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    total_lengths: list[int],
    response_lengths: list[int],
    **kwargs: Any,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Compute the importance sampling (IS) weights and metrics with context parallelism (CP).
    Args:
        train_log_probs: List of log probs from the training backend on this CP rank. 1D tensor each. Lengths can differ.
        rollout_log_probs: List of log probs from the inference backend on this CP rank. 1D tensor each.
        loss_masks: List of loss masks. 1D tensor each.
        total_lengths: List of total lengths.
        response_lengths: List of response lengths.
    Returns:
        is_weights: Importance sampling weights on this CP rank, flattened along dim=0.
        is_metrics: The metrics for the importance sampling weights, a dict mapping each metric name to a 1D tensor,
            also flattened along dim=0.
    """
    # Gather the CP slices from the other CP ranks
    full_rollout_log_probs = [
        all_gather_with_cp(log_prob, total_length, response_length)
        for log_prob, total_length, response_length in zip(rollout_log_probs, total_lengths, response_lengths)
    ]
    full_old_log_probs = [
        all_gather_with_cp(old_log_prob, total_length, response_length)
        for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths)
    ]

    # Main importance sampling logic
    is_weights, is_metrics = compute_mis_weights(
        args=args,
        train_log_probs=full_old_log_probs,
        rollout_log_probs=full_rollout_log_probs,
        loss_masks=loss_masks,
    )

    # Slice out the value shards for this CP rank and concat them into a 1D tensor along dim=0 for loss.py computation.
    def slice_cp_and_concat(
        values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int]
    ) -> torch.Tensor:
        values = [
            # TODO: A rename of this function?
            slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i])
            for i in range(len(values))
        ]
        return torch.cat(values, dim=0)

    result_metrics = {}
    is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths)
    for key, values in is_metrics.items():
        key_name = f"mis_{key}"
        values = slice_cp_and_concat(values, total_lengths, response_lengths)
        result_metrics[key_name] = values

    return is_weights, result_metrics


def add_ppl_metrics(
    train_log_prob: torch.Tensor,
    rollout_log_prob: torch.Tensor,
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
):
    loss_mask = loss_mask.float()

    # 1. Training policy perplexity metrics
    mean_log_prob_training = masked_mean(train_log_prob, loss_mask, expand=True)
    training_log_ppl = -mean_log_prob_training
    training_ppl = torch.exp(training_log_ppl)
    metrics_append(metrics, "training_log_ppl", training_log_ppl)
    metrics_append(metrics, "training_ppl", training_ppl)

    # 2. Rollout policy perplexity metrics
    mean_log_prob_rollout = masked_mean(rollout_log_prob, loss_mask, expand=True)
    rollout_log_ppl = -mean_log_prob_rollout
    rollout_ppl = torch.exp(rollout_log_ppl)
    metrics_append(metrics, "rollout_log_ppl", rollout_log_ppl)
    metrics_append(metrics, "rollout_ppl", rollout_ppl)

    # 3a. kl: direct estimator of KL(π_rollout || π_training)
    # This is the standard KL divergence: E[log(π_rollout) - log(π_training)]
    # A positive value means the rollout policy is more confident than the training policy
    kl_per_token = rollout_log_prob - train_log_prob
    metrics_append(metrics, "kl", kl_per_token)

    # 3b. K3 KL estimator for improved stability
    # More stable for small KL values, using E[exp(log_ratio) - log_ratio - 1]
    # Formula: KL ≈ E[r - log(r) - 1] where r = π_training / π_rollout
    log_ratio = train_log_prob - rollout_log_prob
    k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
    metrics_append(metrics, "k3_kl", k3_kl_matrix)

    # 3c. Log PPL difference (sequence-level perplexity difference)
    # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
    # Since ppl = exp(-log_prob), we have:
    #   log(ppl_ratio) = log(training_ppl / rollout_ppl) = log_ppl_diff
    # A positive value means training assigns lower probability (higher PPL) than rollout
    log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
    metrics_append(metrics, "log_ppl_diff", log_ppl_diff)
    metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs())

    # 3d. PPL ratio (how much higher training PPL is vs rollout PPL)
    # For numerical stability, computed in log space using log_ppl_diff
    # Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff)
    ppl_ratio = torch.exp(log_ppl_diff)
    metrics_append(metrics, "ppl_ratio", ppl_ratio)
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
# Enable importance sampling; for details, refer to the comments of compute_mis_weights in mis.py
use_mis: false

# Aggregation level for importance sampling weights:
#   token: per-token
#   sequence: product over tokens
#   geometric: geometric mean
mis_level: "token"

# Handling mode for IS weights:
#   truncate: cap to the upper bound (TIS)
#   mask: zero outside [lower, upper] (MIS)
#   clip: clip to [lower, upper] (CIS)
mis_mode: "truncate"

# For mask or clip mode, the lower bound of the IS weights.
# For truncate mode, it is not used.
# If not set, it defaults to 1.0 / mis_upper_bound
mis_lower_bound: 0.5

# For truncate, mask, or clip mode, the upper bound of the IS weights
mis_upper_bound: 2.0

# Per-token veto threshold. If any token ratio < this, the entire sequence weight is zeroed, so the sequence receives no gradient.
# Note: the float must be written with a decimal point, e.g. 1.0e-4, not 1e-4
mis_veto_threshold: 1.0e-4
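
As a quick illustration (again not part of the commit) of how the three mis_mode options treat the same weights under the bounds above (lower 0.5, upper 2.0), here is a toy example mirroring the truncate/clip/mask helpers with the loss mask omitted:

import torch

w = torch.tensor([0.2, 0.8, 1.0, 3.0])  # IS weights before any handling
lower, upper = 0.5, 2.0

truncated = w.clamp(0, upper)               # truncate (TIS): [0.2, 0.8, 1.0, 2.0]
clipped = w.clamp(lower, upper)             # clip (CIS):     [0.5, 0.8, 1.0, 2.0]
masked = w * ((w >= lower) & (w <= upper))  # mask (MIS):     [0.0, 0.8, 1.0, 0.0]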
