|
| 1 | +from typing import Any, Dict, Optional, Tuple |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp |
| 6 | + |
| 7 | + |
| 8 | +def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: |
| 9 | + result = (x * loss_mask).sum() |
| 10 | + return result.expand_as(x) if expand else result |
| 11 | + |
| 12 | + |
| 13 | +def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: |
| 14 | + result = masked_sum(x, loss_mask) / torch.clamp_min(loss_mask.sum(), 1) |
| 15 | + return result.expand_as(x) if expand else result |
| 16 | + |
| 17 | + |
| 18 | +def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: |
| 19 | + """ |
| 20 | +
|
| 21 | + Every metrics-dict value is a list of 1D tensor, i.e., [torch.Tensor] with shapes exactly the same as log_probs. |
| 22 | +
|
| 23 | + All metrics will be aggregated and averaged by `sum_of_sample_mean` and divided by DP size automatically |
| 24 | + - If calculate_per_token_loss=False (default), the final results will first be averaged in each sequence, |
| 25 | + then across all the sequences in the global batch. |
| 26 | + - If calculate_per_token_loss=True, the final results will be the mean of all the tokens in the global batch. |
| 27 | +
|
| 28 | + No need to specifically handle loss_mask, sum_of_sample_mean automatically ignores statistics where loss_mask = 0. |
| 29 | +
|
| 30 | + e.g. |
| 31 | + For token-level metric: |
| 32 | + value = [ |
| 33 | + [0.1, 0.2], |
| 34 | + [0.1, 0.2, 0.3, 0.4, 0.5], |
| 35 | + [0.6] |
| 36 | + ] |
| 37 | + When calculate_per_token_loss = False (default): |
| 38 | + result = (0.1 + 0.2) / 2 + (0.1 + 0.2 + 0.3 + 0.4 + 0.5) / 5 + (0.6) / 1 = 0.15 + 0.3 + 0.6 = 1.05 / 3 = 0.35 |
| 39 | + When calculate_per_token_loss = True: |
| 40 | + result = (0.1 + 0.2 + 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6) / 8 = 2.4 / 8 = 0.3 |
| 41 | + For sequence-level metric: |
| 42 | + original sequence lengths = [2, 5, 2] |
| 43 | + We should expand the metrics to the length of each sequence: |
| 44 | + value = [ |
| 45 | + [2, 2], |
| 46 | + [5, 5, 5, 5, 5], |
| 47 | + [1, 1] |
| 48 | + ] |
| 49 | + When calculate_per_token_loss = False (default): |
| 50 | + result = (2 + 2) / 2 + (5 + 5 + 5 + 5 + 5) / 5 + (1 + 1) / 2 = 2 + 5 + 1 = 8 / 3 = 2.6665 |
| 51 | + Note that for sequence-level, calculating token-level loss is invalid; thus, calculate_per_token_loss should always be False. |
| 52 | + """ |
| 53 | + if key not in metrics: |
| 54 | + metrics[key] = [] |
| 55 | + metrics[key].append(value.clone().detach()) |
| 56 | + |
| 57 | + |
| 58 | +def calculate_veto_mask( |
| 59 | + log_ratio: torch.Tensor, |
| 60 | + loss_mask: torch.Tensor, |
| 61 | + veto_threshold: Optional[float], |
| 62 | + metrics: Dict[str, list[torch.Tensor]], |
| 63 | +) -> torch.Tensor: |
| 64 | + if veto_threshold is None: |
| 65 | + return torch.ones_like(log_ratio) |
| 66 | + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio.device)) |
| 67 | + # For each sequence, if it has any catastrophic tokens, return 0 for the sequence |
| 68 | + catastrophic_tokens = ((log_ratio < log_veto_threshold)) & loss_mask.bool() |
| 69 | + has_catastrophic = catastrophic_tokens.any() |
| 70 | + veto_mask = (~has_catastrophic).float().expand_as(log_ratio) |
| 71 | + |
| 72 | + metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int()) |
| 73 | + metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask)) |
| 74 | + return veto_mask |
| 75 | + |
| 76 | + |
| 77 | +def truncate( |
| 78 | + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, list[torch.Tensor]], upper_bound: float |
| 79 | +) -> torch.Tensor: |
| 80 | + assert upper_bound is not None |
| 81 | + metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int()) |
| 82 | + return weights.clamp(0, upper_bound) * loss_mask |
| 83 | + |
| 84 | + |
| 85 | +def clip( |
| 86 | + weights: torch.Tensor, |
| 87 | + loss_mask: torch.Tensor, |
| 88 | + metrics: Dict[str, list[torch.Tensor]], |
| 89 | + lower_bound: float, |
| 90 | + upper_bound: float, |
| 91 | +) -> torch.Tensor: |
| 92 | + assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound |
| 93 | + metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) |
| 94 | + metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int()) |
| 95 | + return weights.clamp(lower_bound, upper_bound) * loss_mask |
| 96 | + |
| 97 | + |
| 98 | +def mask( |
| 99 | + weights: torch.Tensor, |
| 100 | + loss_mask: torch.Tensor, |
| 101 | + metrics: Dict[str, list[torch.Tensor]], |
| 102 | + lower_bound: float, |
| 103 | + upper_bound: float, |
| 104 | +) -> torch.Tensor: |
| 105 | + assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound |
| 106 | + metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int()) |
| 107 | + metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int()) |
| 108 | + mask = (weights >= lower_bound) & (weights <= upper_bound) |
| 109 | + return weights * mask * loss_mask |
| 110 | + |
| 111 | + |
| 112 | +def compute_mis_weights( |
| 113 | + args, |
| 114 | + *, |
| 115 | + train_log_probs: list[torch.Tensor], |
| 116 | + rollout_log_probs: list[torch.Tensor], |
| 117 | + loss_masks: list[torch.Tensor], |
| 118 | +) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]: |
| 119 | + """ |
| 120 | + Compute the importance sampling (IS) weights and metrics between the inference and training engine. |
| 121 | + Args: |
| 122 | + train_log_probs: List of log probs from training backend. 1D tensor each. Lengths can be different. |
| 123 | + rollout_log_probs: List of log probs from inference backend. 1D tensor each. |
| 124 | + loss_masks: List of loss masks. 1D tensor each. |
| 125 | + Note that for single turn RL, the loss_mask is [1] * response_length tensor for each sequence |
| 126 | + For multi-turn RL, the tool response will be marked as 0 in the loss_mask. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + weights: List of importance sampling weights. 1D tensor each. |
| 130 | + metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. |
| 131 | + """ |
| 132 | + |
| 133 | + level: str = args.mis_level |
| 134 | + metrics: Dict[str, list[torch.Tensor]] = {} |
| 135 | + |
| 136 | + if args.mis_lower_bound is None: |
| 137 | + return 1.0 / args.mis_upper_bound |
| 138 | + |
| 139 | + # Validate input lists have same length and each sequence has matching shapes |
| 140 | + assert ( |
| 141 | + len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) |
| 142 | + ), f"Input lists must have the same number of sequences: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}" |
| 143 | + |
| 144 | + for i, (train, rollout, loss_mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): |
| 145 | + assert ( |
| 146 | + train.shape == rollout.shape == loss_mask.shape |
| 147 | + ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, loss_mask: {loss_mask.shape}" |
| 148 | + |
| 149 | + SAFETY_BOUND = 20.0 # Add a safety bound to avoid exp overflow |
| 150 | + all_weights = [] |
| 151 | + |
| 152 | + # handle each sequence independently |
| 153 | + for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks): |
| 154 | + loss_mask = loss_mask.float() |
| 155 | + add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics) |
| 156 | + raw_log_ratio_diff = train_log_prob - rollout_log_prob |
| 157 | + |
| 158 | + # level: The aggregation level for the importance sampling weights. |
| 159 | + if level == "token": |
| 160 | + # Per-token ratio (biased) |
| 161 | + log_ratio_for_metrics = raw_log_ratio_diff |
| 162 | + elif level == "sequence": |
| 163 | + # Product of ratios (unbiased but high variance) |
| 164 | + log_ratio_for_metrics = masked_sum(raw_log_ratio_diff, loss_mask, expand=True) |
| 165 | + elif level == "geometric": |
| 166 | + # Geometric mean of ratios (biased but low variance) |
| 167 | + log_ratio_for_metrics = masked_mean(raw_log_ratio_diff, loss_mask, expand=True) |
| 168 | + else: |
| 169 | + raise ValueError(f"Invalid importance sampling level: {level}") |
| 170 | + |
| 171 | + log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) |
| 172 | + weights = torch.exp(log_ratio_safe) |
| 173 | + metrics_append(metrics, "mean_is_weight_before_clip", weights) |
| 174 | + |
| 175 | + # mask out catastrophic tokens |
| 176 | + if args.mis_veto_threshold is not None: |
| 177 | + veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics) |
| 178 | + |
| 179 | + # mode: how to handle the importance sampling weights exceeding the thresholds. |
| 180 | + if args.mis_mode == "truncate": |
| 181 | + # Cap the importance sampling weights at the upper threshold |
| 182 | + # https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 |
| 183 | + weights = truncate(weights, loss_mask, metrics, args.mis_upper_bound) |
| 184 | + elif args.mis_mode == "mask": |
| 185 | + # Zero the importance sampling weights outside the [lower, upper] range. |
| 186 | + # https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda |
| 187 | + weights = mask( |
| 188 | + weights, |
| 189 | + loss_mask, |
| 190 | + metrics, |
| 191 | + args.mis_lower_bound, |
| 192 | + args.mis_upper_bound, |
| 193 | + ) |
| 194 | + elif args.mis_mode == "clip": |
| 195 | + # Clip the importance sampling weights to the [lower, upper] range. |
| 196 | + # Original behavior in slime. |
| 197 | + weights = clip( |
| 198 | + weights, |
| 199 | + loss_mask, |
| 200 | + metrics, |
| 201 | + args.mis_lower_bound, |
| 202 | + args.mis_upper_bound, |
| 203 | + ) |
| 204 | + else: |
| 205 | + raise ValueError(f"Unsupported mis_mode: {args.mis_mode}") |
| 206 | + |
| 207 | + metrics_append(metrics, "ratio_mean_after_mis", weights) |
| 208 | + if args.mis_veto_threshold is not None: |
| 209 | + weights = weights * veto_mask |
| 210 | + metrics_append(metrics, "ratio_mean_after_veto_mask", weights) |
| 211 | + |
| 212 | + weights = weights.detach() |
| 213 | + all_weights.append(weights) |
| 214 | + |
| 215 | + return all_weights, metrics |
| 216 | + |
| 217 | + |
| 218 | +def compute_mis_weights_with_cp( |
| 219 | + args, |
| 220 | + *, |
| 221 | + train_log_probs: list[torch.Tensor], |
| 222 | + rollout_log_probs: list[torch.Tensor], |
| 223 | + loss_masks: list[torch.Tensor], |
| 224 | + total_lengths: list[int], |
| 225 | + response_lengths: list[int], |
| 226 | + **kwargs: Any, |
| 227 | +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
| 228 | + """ |
| 229 | + Compute the importance sampling (IS) weights and metrics with context parallel. |
| 230 | + Args: |
| 231 | + train_log_probs: List of log probs from training backend on this cp rank. 1D tensor each. Lengths can be different. |
| 232 | + rollout_log_probs: List of log probs from inference backend on this cp rank. 1D tensor each. |
| 233 | + loss_masks: List of loss masks. 1D tensor each. |
| 234 | + total_lengths: List of total lengths. |
| 235 | + response_lengths: List of response lengths. |
| 236 | + Returns: |
| 237 | + is_weights: Importance sampling weights on this CP rank and flattened along dim=0. |
| 238 | + is_metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. |
| 239 | + Also flattened along dim=0. |
| 240 | + """ |
| 241 | + # Gather cp slice from other cp ranks |
| 242 | + full_rollout_log_probs = [ |
| 243 | + all_gather_with_cp(log_prob, total_length, response_length) |
| 244 | + for log_prob, total_length, response_length in zip(rollout_log_probs, total_lengths, response_lengths) |
| 245 | + ] |
| 246 | + full_old_log_probs = [ |
| 247 | + all_gather_with_cp(old_log_prob, total_length, response_length) |
| 248 | + for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths) |
| 249 | + ] |
| 250 | + |
| 251 | + # Main logic for is |
| 252 | + is_weights, is_metrics = compute_mis_weights( |
| 253 | + args=args, |
| 254 | + train_log_probs=full_old_log_probs, |
| 255 | + rollout_log_probs=full_rollout_log_probs, |
| 256 | + loss_masks=loss_masks, |
| 257 | + ) |
| 258 | + |
| 259 | + # Slice out the value shards for this CP rank and concat them into a 1D tensor along dim=0 for loss.py computation. |
| 260 | + def slice_cp_and_concat( |
| 261 | + values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] |
| 262 | + ) -> torch.Tensor: |
| 263 | + values = [ |
| 264 | + # TODO: A rename of this function? |
| 265 | + slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i]) |
| 266 | + for i in range(len(values)) |
| 267 | + ] |
| 268 | + return torch.cat(values, dim=0) |
| 269 | + |
| 270 | + result_metrics = {} |
| 271 | + is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths) |
| 272 | + for key, values in is_metrics.items(): |
| 273 | + key_name = f"mis_{key}" |
| 274 | + values = slice_cp_and_concat(values, total_lengths, response_lengths) |
| 275 | + result_metrics[key_name] = values |
| 276 | + |
| 277 | + return is_weights, result_metrics |
| 278 | + |
| 279 | + |
| 280 | +def add_ppl_metrics( |
| 281 | + train_log_prob: torch.Tensor, |
| 282 | + rollout_log_prob: torch.Tensor, |
| 283 | + loss_mask: torch.Tensor, |
| 284 | + metrics: Dict[str, list[torch.Tensor]], |
| 285 | +): |
| 286 | + loss_mask = loss_mask.float() |
| 287 | + |
| 288 | + # 1. Training policy perplexity metrics |
| 289 | + mean_log_prob_training = masked_mean(train_log_prob, loss_mask, expand=True) |
| 290 | + training_log_ppl = -mean_log_prob_training |
| 291 | + training_ppl = torch.exp(training_log_ppl) |
| 292 | + metrics_append(metrics, "training_log_ppl", training_log_ppl) |
| 293 | + metrics_append(metrics, "training_ppl", training_ppl) |
| 294 | + |
| 295 | + # 2. Rollout policy perplexity metrics |
| 296 | + mean_log_prob_rollout = masked_mean(rollout_log_prob, loss_mask, expand=True) |
| 297 | + rollout_log_ppl = -mean_log_prob_rollout |
| 298 | + rollout_ppl = torch.exp(rollout_log_ppl) |
| 299 | + metrics_append(metrics, "rollout_log_ppl", rollout_log_ppl) |
| 300 | + metrics_append(metrics, "rollout_ppl", rollout_ppl) |
| 301 | + |
| 302 | + # 3a. kl: Direct estimator for KL(π_rollout || π_training) |
| 303 | + # This is the standard KL divergence: E[log(π_rollout) - log(π_training)] |
| 304 | + # Positive value means rollout policy is more confident than training policy |
| 305 | + kl_per_token = rollout_log_prob - train_log_prob |
| 306 | + metrics_append(metrics, "kl", kl_per_token) |
| 307 | + |
| 308 | + # 3b. K3 KL estimator for improved stability |
| 309 | + # More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1] |
| 310 | + # Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout |
| 311 | + log_ratio = train_log_prob - rollout_log_prob |
| 312 | + k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1 |
| 313 | + metrics_append(metrics, "k3_kl", k3_kl_matrix) |
| 314 | + |
| 315 | + # 3c. Log PPL difference (sequence-level perplexity difference) |
| 316 | + # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training |
| 317 | + # Since ppl = exp(-log_prob), we have: |
| 318 | + # log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff |
| 319 | + # Positive value means training assigns lower probability (higher PPL) than rollout |
| 320 | + log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training |
| 321 | + metrics_append(metrics, "log_ppl_diff", log_ppl_diff) |
| 322 | + metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs()) |
| 323 | + |
| 324 | + # 3d. PPL ratio (how much higher is training PPL vs rollout PPL) |
| 325 | + # For numerical stability, compute in log space using log_ppl_diff |
| 326 | + # Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff) |
| 327 | + ppl_ratio = torch.exp(log_ppl_diff) |
| 328 | + metrics_append(metrics, "ppl_ratio", ppl_ratio) |
0 commit comments