 
 import torch
 
-from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp
+# NOTE:
+# - `compute_mis_weights` is a lightweight, standalone function that is useful to unit-test on CPU.
+# - `compute_mis_weights_with_cp` depends on Megatron context-parallel utilities, which are heavy and may not be
+#   available in minimal environments.
+# To keep `mis.py` importable for unit tests, we lazily import CP utilities inside `compute_mis_weights_with_cp`.
 
 
 def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
@@ -15,6 +19,26 @@ def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False)
     return result.expand_as(x) if expand else result
 
 
+def masked_min(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
+    """Masked min over valid tokens (loss_mask == 1). Returns 0 when mask is empty."""
+    mask = loss_mask.bool()
+    if mask.any():
+        result = x[mask].min()
+    else:
+        result = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+    return result.expand_as(x) if expand else result
+
+
+def masked_max(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
+    """Masked max over valid tokens (loss_mask == 1). Returns 0 when mask is empty."""
+    mask = loss_mask.bool()
+    if mask.any():
+        result = x[mask].max()
+    else:
+        result = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+    return result.expand_as(x) if expand else result
+
+
 def metrics_append(metrics: dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None:
     """
 
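The new reducers mirror the existing `masked_sum`/`masked_mean` helpers and run fine on CPU. A quick sketch, assuming the helpers above are in scope (tensor values are made up):

```python
import torch

x = torch.tensor([0.5, 2.0, 1.5, 9.0])
loss_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])  # last token is padding

masked_min(x, loss_mask)               # tensor(0.5000) -- the padded 9.0 is ignored
masked_max(x, loss_mask, expand=True)  # tensor([2., 2., 2., 2.]), broadcast back to x's shape
masked_max(x, torch.zeros(4))          # tensor(0.) -- empty mask falls back to 0
```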
@@ -60,6 +84,8 @@ def calculate_veto_mask(
     loss_mask: torch.Tensor,
     veto_threshold: float | None,
     metrics: dict[str, list[torch.Tensor]],
+    *,
+    metric_prefix: str = "",
 ) -> torch.Tensor:
     if veto_threshold is None:
         return torch.ones_like(log_ratio)
@@ -69,16 +95,21 @@ def calculate_veto_mask(
     has_catastrophic = catastrophic_tokens.any()
     veto_mask = (~has_catastrophic).float().expand_as(log_ratio)
 
-    metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int())
-    metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask))
+    metrics_append(metrics, f"{metric_prefix}catastrophic_token_fraction", catastrophic_tokens.int())
+    metrics_append(metrics, f"{metric_prefix}catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask))
     return veto_mask
 
 
 def truncate(
-    weights: torch.Tensor, loss_mask: torch.Tensor, metrics: dict[str, list[torch.Tensor]], upper_bound: float
+    weights: torch.Tensor,
+    loss_mask: torch.Tensor,
+    metrics: dict[str, list[torch.Tensor]],
+    upper_bound: float,
+    *,
+    metric_prefix: str = "",
 ) -> torch.Tensor:
     assert upper_bound is not None
-    metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int())
+    metrics_append(metrics, f"{metric_prefix}truncate_fraction", (weights > upper_bound).int())
     return weights.clamp(0, upper_bound) * loss_mask
 
 
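A sketch of what the new keyword changes in practice (bounds and values illustrative): the same helper can now be reused by both the TIS and RS paths without their metrics colliding.

```python
metrics: dict[str, list[torch.Tensor]] = {}
weights = torch.tensor([0.5, 3.0, 1.0])
loss_mask = torch.ones(3)

out = truncate(weights, loss_mask, metrics, 2.0, metric_prefix="tis_")
# out == tensor([0.5, 2.0, 1.0]); the over-bound fraction is recorded under
# "tis_truncate_fraction" rather than the old shared "truncate_fraction" key.
```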
@@ -88,10 +119,12 @@ def clip(
     metrics: dict[str, list[torch.Tensor]],
     lower_bound: float,
     upper_bound: float,
+    *,
+    metric_prefix: str = "",
 ) -> torch.Tensor:
     assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
-    metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int())
-    metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int())
+    metrics_append(metrics, f"{metric_prefix}clip_fraction_low", (weights < lower_bound).int())
+    metrics_append(metrics, f"{metric_prefix}clip_fraction_high", (weights > upper_bound).int())
     return weights.clamp(lower_bound, upper_bound) * loss_mask
 
 
@@ -101,10 +134,12 @@ def mask(
     metrics: dict[str, list[torch.Tensor]],
     lower_bound: float,
     upper_bound: float,
+    *,
+    metric_prefix: str = "",
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
-    metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int())
-    metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int())
+    metrics_append(metrics, f"{metric_prefix}mask_fraction_low", (weights < lower_bound).int())
+    metrics_append(metrics, f"{metric_prefix}mask_fraction_high", (weights > upper_bound).int())
     in_range = (weights >= lower_bound) & (weights <= upper_bound)
     modified_mask = loss_mask * in_range.float()
     # Zero out padding in weights but preserve values at non-rejected positions
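Unlike `truncate` and `clip`, `mask` rejects out-of-range tokens via the returned mask rather than altering their values. A sketch, assuming the helper above is in scope (bounds illustrative):

```python
metrics: dict[str, list[torch.Tensor]] = {}
weights = torch.tensor([0.1, 1.0, 5.0])
loss_mask = torch.ones(3)

_, modified = mask(weights, loss_mask, metrics, 0.5, 2.0, metric_prefix="rs_")
# modified == tensor([0., 1., 0.]): only the in-range token survives, and the
# rejection rates land under "rs_mask_fraction_low" / "rs_mask_fraction_high".
```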
@@ -189,11 +224,15 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
     metrics_append(metrics, "tis_weight_before_bound", weights)
 
     if args.tis_mode == "truncate":
-        weights = truncate(weights, loss_mask, metrics, args.tis_upper_bound)
+        weights = truncate(weights, loss_mask, metrics, args.tis_upper_bound, metric_prefix="tis_")
     elif args.tis_mode == "clip":
-        weights = clip(weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound)
+        weights = clip(
+            weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound, metric_prefix="tis_"
+        )
     elif args.tis_mode == "mask":
-        weights, modified_mask = mask(weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound)
+        weights, modified_mask = mask(
+            weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound, metric_prefix="tis_"
+        )
     else:
         raise ValueError(f"Unsupported tis_mode: {args.tis_mode}")
 
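For intuition, the three `tis_mode` strategies handle the same out-of-range weights differently (a sketch with illustrative bounds, helpers above in scope):

```python
w = torch.tensor([0.1, 1.0, 5.0])
m = torch.ones(3)

truncate(w, m, {}, 2.0)      # tensor([0.1000, 1.0000, 2.0000]) -- one-sided clamp to [0, upper]
clip(w, m, {}, 0.5, 2.0)     # tensor([0.5000, 1.0000, 2.0000]) -- two-sided clamp
mask(w, m, {}, 0.5, 2.0)[1]  # tensor([0., 1., 0.]) -- out-of-range tokens rejected via the mask
```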
@@ -212,14 +251,18 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
         rs_weights = torch.exp(log_ratio_safe_rs)
 
         # Apply mask-based rejection sampling
-        _, modified_mask = mask(rs_weights, modified_mask, metrics, rs_lower_bound, rs_upper_bound)
+        _, modified_mask = mask(
+            rs_weights, modified_mask, metrics, rs_lower_bound, rs_upper_bound, metric_prefix="rs_"
+        )
 
         # Veto on raw per-token ratios (sequence-wise rejection)
         if args.rs_veto_threshold is not None:
-            veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.rs_veto_threshold, metrics)
+            veto_mask = calculate_veto_mask(
+                raw_log_ratio_diff, loss_mask, args.rs_veto_threshold, metrics, metric_prefix="rs_"
+            )
             modified_mask = modified_mask * veto_mask
 
-    metrics_append(metrics, "ratio_mean_after_tis", weights)
+    metrics_append(metrics, "is_ratio_mean_after_tis_rs", weights)
 
     weights = weights.detach()
     modified_mask = modified_mask.detach()
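Note that the veto is all-or-nothing per sequence: `calculate_veto_mask` builds its result from `(~catastrophic_tokens.any()).float().expand_as(...)`, so one catastrophic token zeroes every position (illustrative values):

```python
modified_mask = torch.tensor([1.0, 1.0, 0.0])
veto_mask = torch.zeros(3)                 # sequence contained a catastrophic token
modified_mask = modified_mask * veto_mask  # tensor([0., 0., 0.]) -- whole sequence rejected
```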
@@ -253,6 +296,14 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str
     for w in all_weights:
         metrics_append(metrics, "batch_norm_factor", torch.ones_like(w))
 
+    # Final weight stats (after optional batch normalization).
+    # NOTE: These are expanded to token-shape so that the existing mean-reducer can aggregate them.
+    for w, m in zip(all_weights, loss_masks, strict=False):
+        m = m.float()
+        metrics_append(metrics, "is_ratio_mean_final", masked_mean(w, m, expand=True))
+        metrics_append(metrics, "is_ratio_min_final", masked_min(w, m, expand=True))
+        metrics_append(metrics, "is_ratio_max_final", masked_max(w, m, expand=True))
+
     return all_weights, all_modified_masks, metrics
 
 
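The expansion trick works because a masked mean over a constant vector returns that constant, so the min/max survive a mean-reduction downstream (illustrative values, helpers above in scope):

```python
w = torch.tensor([0.5, 2.0, 1.0])
m = torch.tensor([1.0, 1.0, 0.0])

per_token = masked_min(w, m, expand=True)  # tensor([0.5000, 0.5000, 0.5000])
masked_mean(per_token, m)                  # tensor(0.5000) -- the min is preserved
```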
@@ -280,6 +331,9 @@ def compute_mis_weights_with_cp(
         modified_masks: List of modified response masks with rejection applied (one per sequence).
         is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors.
     """
+    # Lazy import to avoid importing Megatron dependencies when only `compute_mis_weights` is used.
+    from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp
+
     # Gather cp slice from other cp ranks
     full_rollout_log_probs = [
         all_gather_with_cp(log_prob, total_length, response_length)
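With the import deferred like this, the pure-PyTorch path can be unit-tested without Megatron installed. A hypothetical CPU test (the module path and test name are assumptions):

```python
import torch
import mis  # hypothetical: assumes mis.py is on the test's sys.path

def test_masked_reducers_ignore_padding():
    x = torch.tensor([0.5, 9.0])
    m = torch.tensor([1.0, 0.0])
    assert mis.masked_min(x, m).item() == 0.5
    assert mis.masked_max(x, m).item() == 0.5
    assert mis.masked_max(x, torch.zeros(2)).item() == 0.0  # empty mask -> 0
```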