-
Notifications
You must be signed in to change notification settings - Fork 224
[skyrl-train] Refactor TIS to use more comprehensive off policy correction config #849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
f033e65
0b236fe
3f3b759
29efd6f
1520157
45a59c2
ce01bb2
abac800
2dc7364
349369d
f3f7054
c45c130
63d38c5
7e83c10
cf042fc
cef7121
9485bdd
9e11eda
c06747c
0b5ebfd
0697957
6b9e1e4
46b6fe5
d72d9c6
db76d01
ac0659c
7ddb85f
6c8d084
08c3625
2bca41f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -552,6 +552,152 @@ def _safe_exp_delta(delta: torch.Tensor, clip: float = 20.0, out_dtype=None) -> | |
| return y.to(out_dtype or delta.dtype) | ||
|
|
||
|
|
||
def compute_tis_ratio(
    old_log_probs: torch.Tensor,
    rollout_logprobs: torch.Tensor,
    loss_mask: torch.Tensor,
    tis_ratio_type: str,
    rollout_corr: DictConfig,
) -> torch.Tensor:
    """
    Compute the truncated importance sampling (TIS) ratio for rollout correction.

    The importance ratio is pi_old / pi_rollout, evaluated in log space as
    ``old_log_probs - rollout_logprobs``, exponentiated through
    ``_safe_exp_delta`` and then capped from above only (no lower cap).

    Args:
        old_log_probs: Log probabilities from the old policy (before update).
        rollout_logprobs: Log probabilities from the rollout policy.
        loss_mask: Mask indicating valid tokens. Only read for the "sequence"
            ratio type, where it selects which token log-ratios are summed.
        tis_ratio_type: Type of TIS ratio ("token" or "sequence").
        rollout_corr: Rollout correction config. ``token_tis_ratio_cap_high``
            is read for "token", ``sequence_tis_ratio_cap_high`` for "sequence".

    Returns:
        TIS ratio tensor to multiply with the loss. For "token" it is
        element-wise (same shape as ``old_log_probs - rollout_logprobs``); for
        "sequence" the last dimension is reduced to size 1 (``keepdim=True``)
        so it broadcasts over the tokens of each sequence.

    Raises:
        ValueError: If ``tis_ratio_type`` is neither "token" nor "sequence".

    Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
    """
    # NOTE(review): a reviewer on the originating PR asked for the rollout
    # correction helpers to be moved into a separate file.

    # Token-level log importance ratio: log(pi_old) - log(pi_rollout).
    token_tis_log_ratio = old_log_probs - rollout_logprobs

    if tis_ratio_type == "token":
        # Exponentiate only here: the token-level ratio is not needed by the
        # sequence branch, so computing it up front would be wasted work.
        token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
        return torch.clamp(token_tis_ratio, max=rollout_corr.token_tis_ratio_cap_high)

    if tis_ratio_type == "sequence":
        # Product of token ratios == exp(sum of masked token log-ratios).
        seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)
        seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
        return torch.clamp(seq_tis_ratio, max=rollout_corr.sequence_tis_ratio_cap_high)

    raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}")
|
|
||
|
|
||
def compute_rejection_mask(
    old_log_probs: torch.Tensor,
    rollout_logprobs: torch.Tensor,
    loss_mask: torch.Tensor,
    rejection_mask_type: str,
    rollout_corr: DictConfig,
) -> torch.Tensor:
    """
    Compute the rejection mask for rollout correction.

    Sequences whose importance ratio (pi_old / pi_rollout) falls outside the
    configured bounds are masked out, filtering off-policy samples that may
    destabilize training.

    Args:
        old_log_probs: Log probabilities from the old policy (before update).
        rollout_logprobs: Log probabilities from the rollout policy.
        loss_mask: Mask indicating valid tokens; positions where it equals 0
            are ignored by every check.
        rejection_mask_type: Type of rejection mask ("geometric" or "sequence").
        rollout_corr: Rollout correction config. "geometric" reads
            ``geo_rejection_mask_ratio_cap_low/high``. "sequence" reads
            ``sequence_rejection_mask_ratio_cap_low/high`` for the sequence-level
            ratio and ``sequence_mask_low/high`` for the per-token ratio bounds.

    Returns:
        Rejection mask tensor (float) to multiply with the loss: 1.0 for kept
        sequences and 0.0 for rejected ones. The last dimension has size 1
        (``keepdim=True``) so it broadcasts over the tokens of each sequence.

    Raises:
        ValueError: If ``rejection_mask_type`` is neither "geometric" nor
            "sequence".

    Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
    """
    # Token-level log importance ratio: log(pi_old) - log(pi_rollout).
    token_tis_log_ratio = old_log_probs - rollout_logprobs

    if rejection_mask_type not in ("geometric", "sequence"):
        raise ValueError(f"Unknown rejection_mask_type: {rejection_mask_type}")

    # Both mask types start from the masked sum of token log-ratios.
    seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)

    if rejection_mask_type == "geometric":
        # Geometric mean of the token ratios == exp(mean of masked log-ratios).
        # Clamping the count to >= 1 avoids dividing by zero when a sequence
        # has no valid tokens.
        num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
        geo_mean_ratio = _safe_exp_delta(seq_tis_log_ratio / num_tokens, clip=20.0, out_dtype=old_log_probs.dtype)
        geo_in_bounds = (geo_mean_ratio >= rollout_corr.geo_rejection_mask_ratio_cap_low) & (
            geo_mean_ratio <= rollout_corr.geo_rejection_mask_ratio_cap_high
        )
        return geo_in_bounds.float()

    # "sequence": the product of token ratios must lie within the sequence caps...
    seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
    seq_in_bounds = (seq_tis_ratio >= rollout_corr.sequence_rejection_mask_ratio_cap_low) & (
        seq_tis_ratio <= rollout_corr.sequence_rejection_mask_ratio_cap_high
    )

    # ...and every valid token's own ratio must lie within the per-token bounds.
    # The token-level exp is computed only here, since the geometric branch
    # never uses it.
    token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
    token_in_bounds = (token_tis_ratio >= rollout_corr.sequence_mask_low) & (
        token_tis_ratio <= rollout_corr.sequence_mask_high
    )
    # Positions with loss_mask == 0 count as valid so they cannot trigger a rejection.
    all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True)

    return (seq_in_bounds & all_tokens_valid).float()
|
|
||
|
|
||
def apply_rollout_correction(
    loss: torch.Tensor,
    old_log_probs: torch.Tensor,
    rollout_logprobs: torch.Tensor,
    loss_mask: torch.Tensor,
    rollout_corr: DictConfig,
) -> torch.Tensor:
    """
    Scale the loss by the configured TIS ratio and/or rejection mask.

    Thin wrapper around ``compute_tis_ratio`` and ``compute_rejection_mask``:
    each correction is applied only when its type in ``rollout_corr`` is set
    to something other than ``None`` or the string "null". The TIS ratio is
    applied first, then the rejection mask.

    Args:
        loss: The loss tensor to correct.
        old_log_probs: Log probabilities from the old policy (before update).
        rollout_logprobs: Log probabilities from the rollout policy.
        loss_mask: Mask indicating valid tokens.
        rollout_corr: Rollout correction config; ``tis_ratio_type`` and
            ``rejection_mask_type`` select which corrections are active.

    Returns:
        Corrected loss tensor. When neither correction is enabled, ``loss`` is
        returned unchanged.
    """

    def _is_enabled(kind) -> bool:
        # A correction is off when its type is unset or the literal "null".
        return kind is not None and kind != "null"

    ratio_kind = rollout_corr.tis_ratio_type
    mask_kind = rollout_corr.rejection_mask_type

    corrected = loss

    if _is_enabled(ratio_kind):
        corrected = corrected * compute_tis_ratio(
            old_log_probs, rollout_logprobs, loss_mask, ratio_kind, rollout_corr
        )

    if _is_enabled(mask_kind):
        corrected = corrected * compute_rejection_mask(
            old_log_probs, rollout_logprobs, loss_mask, mask_kind, rollout_corr
        )

    return corrected
|
|
||
erictang000 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @register_policy_loss(PolicyLossType.REGULAR) | ||
| @register_policy_loss(PolicyLossType.DUAL_CLIP) | ||
| def ppo_policy_loss( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The return is typed as |
||
|
|
@@ -581,6 +727,7 @@ def ppo_policy_loss( | |
| clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) | ||
| loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) | ||
|
|
||
| # Legacy TIS support (deprecated) | ||
| if config.use_tis: | ||
| from loguru import logger as logger_ | ||
|
|
||
|
|
@@ -590,6 +737,11 @@ def ppo_policy_loss( | |
| tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) | ||
| loss = loss * tis_imp_ratio | ||
|
|
||
| # New rollout correction support | ||
| rollout_corr = config.get("rollout_correction", None) | ||
| if rollout_corr is not None and rollout_logprobs is not None: | ||
| loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) | ||
|
|
||
| loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) | ||
| return loss, clip_ratio | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.