Skip to content

Commit 57b9c3d

Browse files
ChangyiYangslime-local
andauthored
[TIS/MIS] fix and add better metric (#1174)
Co-authored-by: slime-local <slime-local@localhost>
1 parent b74858a commit 57b9c3d

File tree

2 files changed

+83
-17
lines changed
  • examples/train_infer_mismatch_helper
  • slime/backends/megatron_utils

2 files changed

+83
-17
lines changed

examples/train_infer_mismatch_helper/mis.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
import torch
44

5-
from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp
5+
# NOTE:
6+
# - `compute_mis_weights` is a lightweight, standalone function that is useful to unit-test on CPU.
7+
# - `compute_mis_weights_with_cp` depends on Megatron context-parallel utilities, which are heavy and may not be
8+
# available in minimal environments.
9+
# To keep `mis.py` importable for unit tests, we lazily import CP utilities inside `compute_mis_weights_with_cp`.
610

711

812
def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
@@ -15,6 +19,26 @@ def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False)
1519
return result.expand_as(x) if expand else result
1620

1721

22+
def masked_min(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
23+
"""Masked min over valid tokens (loss_mask == 1). Returns 0 when mask is empty."""
24+
mask = loss_mask.bool()
25+
if mask.any():
26+
result = x[mask].min()
27+
else:
28+
result = torch.tensor(0.0, device=x.device, dtype=x.dtype)
29+
return result.expand_as(x) if expand else result
30+
31+
32+
def masked_max(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
33+
"""Masked max over valid tokens (loss_mask == 1). Returns 0 when mask is empty."""
34+
mask = loss_mask.bool()
35+
if mask.any():
36+
result = x[mask].max()
37+
else:
38+
result = torch.tensor(0.0, device=x.device, dtype=x.dtype)
39+
return result.expand_as(x) if expand else result
40+
41+
1842
def metrics_append(metrics: dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None:
1943
"""
2044
@@ -60,6 +84,8 @@ def calculate_veto_mask(
6084
loss_mask: torch.Tensor,
6185
veto_threshold: float | None,
6286
metrics: dict[str, list[torch.Tensor]],
87+
*,
88+
metric_prefix: str = "",
6389
) -> torch.Tensor:
6490
if veto_threshold is None:
6591
return torch.ones_like(log_ratio)
@@ -69,16 +95,21 @@ def calculate_veto_mask(
6995
has_catastrophic = catastrophic_tokens.any()
7096
veto_mask = (~has_catastrophic).float().expand_as(log_ratio)
7197

72-
metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int())
73-
metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask))
98+
metrics_append(metrics, f"{metric_prefix}catastrophic_token_fraction", catastrophic_tokens.int())
99+
metrics_append(metrics, f"{metric_prefix}catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask))
74100
return veto_mask
75101

76102

77103
def truncate(
78-
weights: torch.Tensor, loss_mask: torch.Tensor, metrics: dict[str, list[torch.Tensor]], upper_bound: float
104+
weights: torch.Tensor,
105+
loss_mask: torch.Tensor,
106+
metrics: dict[str, list[torch.Tensor]],
107+
upper_bound: float,
108+
*,
109+
metric_prefix: str = "",
79110
) -> torch.Tensor:
80111
assert upper_bound is not None
81-
metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int())
112+
metrics_append(metrics, f"{metric_prefix}truncate_fraction", (weights > upper_bound).int())
82113
return weights.clamp(0, upper_bound) * loss_mask
83114

84115

@@ -88,10 +119,12 @@ def clip(
88119
metrics: dict[str, list[torch.Tensor]],
89120
lower_bound: float,
90121
upper_bound: float,
122+
*,
123+
metric_prefix: str = "",
91124
) -> torch.Tensor:
92125
assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
93-
metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int())
94-
metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int())
126+
metrics_append(metrics, f"{metric_prefix}clip_fraction_low", (weights < lower_bound).int())
127+
metrics_append(metrics, f"{metric_prefix}clip_fraction_high", (weights > upper_bound).int())
95128
return weights.clamp(lower_bound, upper_bound) * loss_mask
96129

97130

@@ -101,10 +134,12 @@ def mask(
101134
metrics: dict[str, list[torch.Tensor]],
102135
lower_bound: float,
103136
upper_bound: float,
137+
*,
138+
metric_prefix: str = "",
104139
) -> tuple[torch.Tensor, torch.Tensor]:
105140
assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
106-
metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int())
107-
metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int())
141+
metrics_append(metrics, f"{metric_prefix}mask_fraction_low", (weights < lower_bound).int())
142+
metrics_append(metrics, f"{metric_prefix}mask_fraction_high", (weights > upper_bound).int())
108143
in_range = (weights >= lower_bound) & (weights <= upper_bound)
109144
modified_mask = loss_mask * in_range.float()
110145
# Zero out padding in weights but preserve values at non-rejected positions
@@ -189,11 +224,15 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
189224
metrics_append(metrics, "tis_weight_before_bound", weights)
190225

191226
if args.tis_mode == "truncate":
192-
weights = truncate(weights, loss_mask, metrics, args.tis_upper_bound)
227+
weights = truncate(weights, loss_mask, metrics, args.tis_upper_bound, metric_prefix="tis_")
193228
elif args.tis_mode == "clip":
194-
weights = clip(weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound)
229+
weights = clip(
230+
weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound, metric_prefix="tis_"
231+
)
195232
elif args.tis_mode == "mask":
196-
weights, modified_mask = mask(weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound)
233+
weights, modified_mask = mask(
234+
weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound, metric_prefix="tis_"
235+
)
197236
else:
198237
raise ValueError(f"Unsupported tis_mode: {args.tis_mode}")
199238

@@ -212,14 +251,18 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
212251
rs_weights = torch.exp(log_ratio_safe_rs)
213252

214253
# Apply mask-based rejection sampling
215-
_, modified_mask = mask(rs_weights, modified_mask, metrics, rs_lower_bound, rs_upper_bound)
254+
_, modified_mask = mask(
255+
rs_weights, modified_mask, metrics, rs_lower_bound, rs_upper_bound, metric_prefix="rs_"
256+
)
216257

217258
# Veto on raw per-token ratios (sequence-wise rejection)
218259
if args.rs_veto_threshold is not None:
219-
veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.rs_veto_threshold, metrics)
260+
veto_mask = calculate_veto_mask(
261+
raw_log_ratio_diff, loss_mask, args.rs_veto_threshold, metrics, metric_prefix="rs_"
262+
)
220263
modified_mask = modified_mask * veto_mask
221264

222-
metrics_append(metrics, "ratio_mean_after_tis", weights)
265+
metrics_append(metrics, "is_ratio_mean_after_tis_rs", weights)
223266

224267
weights = weights.detach()
225268
modified_mask = modified_mask.detach()
@@ -253,6 +296,14 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
253296
for w in all_weights:
254297
metrics_append(metrics, "batch_norm_factor", torch.ones_like(w))
255298

299+
# Final weight stats (after optional batch normalization).
300+
# NOTE: These are expanded to token-shape so that the existing mean-reducer can aggregate them.
301+
for w, m in zip(all_weights, loss_masks, strict=False):
302+
m = m.float()
303+
metrics_append(metrics, "is_ratio_mean_final", masked_mean(w, m, expand=True))
304+
metrics_append(metrics, "is_ratio_min_final", masked_min(w, m, expand=True))
305+
metrics_append(metrics, "is_ratio_max_final", masked_max(w, m, expand=True))
306+
256307
return all_weights, all_modified_masks, metrics
257308

258309

@@ -280,6 +331,9 @@ def compute_mis_weights_with_cp(
280331
modified_masks: List of modified response masks with rejection applied (one per sequence).
281332
is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors.
282333
"""
334+
# Lazy import to avoid importing Megatron dependencies when only `compute_mis_weights` is used.
335+
from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp
336+
283337
# Gather cp slice from other cp ranks
284338
full_rollout_log_probs = [
285339
all_gather_with_cp(log_prob, total_length, response_length)

slime/backends/megatron_utils/loss.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,16 @@ def policy_loss_function(
507507

508508
# Apply off-policy correction using importance sampling if enabled
509509
if args.get_mismatch_metrics or args.use_tis:
510+
# NOTE:
511+
# `tis_func` may apply rejection-sampling style masking (RS) and return `modified_response_masks`.
512+
# We rebuild `sum_of_sample_mean` with those masks to correct denominators for loss/backprop.
513+
#
514+
# However, mismatch/TIS/RS metrics (e.g., "truncate_fraction") are often defined over the
515+
# *pre-RS* valid tokens. If we aggregate metrics with `modified_response_masks`, the rejected
516+
# tokens are excluded from the denominator and the metric can be artificially driven to 0.
517+
# Keep a copy of the original reducer (based on `batch["loss_masks"]`) for metric aggregation.
518+
sum_of_sample_mean_for_mismatch_metrics = sum_of_sample_mean
519+
510520
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"
511521

512522
ois = (-ppo_kl).exp()
@@ -583,11 +593,13 @@ def policy_loss_function(
583593
reported_loss["kl_loss"] = kl_loss.clone().detach()
584594

585595
if args.get_mismatch_metrics or args.use_tis:
586-
reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach()
596+
# Aggregate mismatch/TIS/RS related metrics with the *pre-RS* masks.
597+
# See comment above where `sum_of_sample_mean_for_mismatch_metrics` is defined.
598+
reported_loss["ois"] = sum_of_sample_mean_for_mismatch_metrics(ois).clone().detach()
587599
# Assume all metrics are already cloned and detached
588600
for metric_key, metric_value in tis_metrics.items():
589601
key_name = f"{metric_key}"
590-
reported_loss[key_name] = sum_of_sample_mean(metric_value)
602+
reported_loss[key_name] = sum_of_sample_mean_for_mismatch_metrics(metric_value)
591603

592604
if args.use_opsm:
593605
reported_loss["opsm_clipfrac"] = opsm_clipfrac

0 commit comments

Comments
 (0)