diff --git a/tests/algorithm/kl_fn_test.py b/tests/algorithm/kl_fn_test.py
new file mode 100644
index 0000000000..f3771c5f7c
--- /dev/null
+++ b/tests/algorithm/kl_fn_test.py
@@ -0,0 +1,141 @@
+# -*- coding: utf-8 -*-
+"""Test for KL functions"""
+
+import unittest
+
+import torch
+
+from trinity.algorithm.kl_fn.kl_fn import KL_FN
+
+
+class KLFnTest(unittest.TestCase):
+    def setUp(self):
+        seed = 42
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+        shape = (4, 10)
+        self.logprob = 2 * torch.rand(shape) - 1
+        self.ref_logprob = 2 * torch.rand(shape) - 1
+        self.old_logprob = 2 * torch.rand(shape) - 1
+        self.response_mask = torch.rand(shape) > 0.5
+
+    def test_k1_kl_fn(self):
+        kl_fn_cls = KL_FN.get("k1")
+        kl_fn = kl_fn_cls(kl_coef=0.01)
+        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
+        expected_kl = self.logprob - self.ref_logprob
+        self.assertTrue(torch.allclose(kl, expected_kl))
+
+    def test_k2_kl_fn(self):
+        kl_fn_cls = KL_FN.get("k2")
+        kl_fn = kl_fn_cls(kl_coef=0.01)
+        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
+        expected_kl = (self.logprob - self.ref_logprob).square() * 0.5
+        self.assertTrue(torch.allclose(kl, expected_kl))
+
+    def test_k3_kl_fn(self):
+        kl_fn_cls = KL_FN.get("k3")
+        kl_fn = kl_fn_cls(kl_coef=0.01)
+        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
+        logr = self.ref_logprob - self.logprob
+        expected_kl = logr.exp() - 1 - logr
+        self.assertTrue(torch.allclose(kl, expected_kl))
+
+    def test_abs_kl_fn(self):
+        kl_fn_cls = KL_FN.get("abs")
+        kl_fn = kl_fn_cls(kl_coef=0.01)
+        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
+        expected_kl = torch.abs(self.logprob - self.ref_logprob)
+        self.assertTrue(torch.allclose(kl, expected_kl))
+
+    def test_low_var_kl_fn(self):
+        kl_fn_cls = KL_FN.get("low_var_kl")
+        kl_fn = kl_fn_cls(kl_coef=0.01)
+        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
+        kl_intermediate = self.ref_logprob - self.logprob
+        kl_intermediate = torch.clamp(kl_intermediate, min=-20, max=20)
+        ratio = torch.exp(kl_intermediate)
+        expected_kl = torch.clamp((ratio - kl_intermediate - 1).contiguous(), min=-10, max=10)
+        self.assertTrue(torch.allclose(kl, expected_kl))
+
+    def test_dummy_kl_fn(self):
+        kl_fn_cls = KL_FN.get("none")
+        kl_fn = kl_fn_cls(kl_coef=0.01)
+        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
+        expected_kl = torch.zeros_like(self.logprob)
+        self.assertTrue(torch.allclose(kl, expected_kl))
+
+    def test_corrected_k3_fallback(self):
+        k3_fn = KL_FN.get("k3")(kl_coef=0.01)
+        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
+        kl_standard = k3_fn.calculate_kl(self.logprob, self.ref_logprob)
+        kl_corrected_no_old = corrected_k3_fn.calculate_kl(
+            self.logprob, self.ref_logprob, old_logprob=None
+        )
+        self.assertTrue(torch.allclose(kl_standard, kl_corrected_no_old))
+
+    def test_corrected_k3_with_old_logprob(self):
+        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
+        kl_corrected = corrected_k3_fn.calculate_kl(
+            self.logprob, self.ref_logprob, self.old_logprob
+        )
+        logr = self.ref_logprob - self.logprob
+        kl_standard = logr.exp() - 1 - logr
+        log_ratio_is = self.logprob - self.old_logprob
+        ratio_is = log_ratio_is.exp()
+        ratio_is = torch.clamp(ratio_is, min=0.0, max=2.0)
+        expected_kl = ratio_is * kl_standard
+        self.assertTrue(torch.allclose(kl_corrected, expected_kl))
+
+    def test_corrected_k3_same_policy(self):
+        k3_fn = KL_FN.get("k3")(kl_coef=0.01)
+        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
+        kl_standard = k3_fn.calculate_kl(self.logprob, self.ref_logprob)
+        kl_corrected = corrected_k3_fn.calculate_kl(self.logprob, self.ref_logprob, self.logprob)
+        self.assertTrue(torch.allclose(kl_standard, kl_corrected, rtol=1e-4, atol=1e-6))
+
+    def test_corrected_k3_loss(self):
+        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
+        kl_loss, metrics = corrected_k3_fn.calculate_kl_loss(
+            logprob=self.logprob,
+            ref_logprob=self.ref_logprob,
+            response_mask=self.response_mask,
+            loss_agg_mode="token-mean",
+            old_logprob=self.old_logprob,
+        )
+        self.assertEqual(kl_loss.dim(), 0)
+        self.assertIn("kl_loss", metrics)
+        self.assertIn("kl_coef", metrics)
+        self.assertEqual(metrics["kl_coef"], 0.01)
+
+    def test_kl_loss_aggregation_modes(self):
+        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
+        kl_loss_mean, _ = corrected_k3_fn.calculate_kl_loss(
+            logprob=self.logprob,
+            ref_logprob=self.ref_logprob,
+            response_mask=self.response_mask,
+            loss_agg_mode="token-mean",
+            old_logprob=self.old_logprob,
+        )
+        kl_loss_sum, _ = corrected_k3_fn.calculate_kl_loss(
+            logprob=self.logprob,
+            ref_logprob=self.ref_logprob,
+            response_mask=self.response_mask,
+            loss_agg_mode="seq-mean-token-sum",
+            old_logprob=self.old_logprob,
+        )
+        self.assertGreater(kl_loss_sum.item(), kl_loss_mean.item())
+
+    def test_corrected_k3_finite_grad_on_large_ratio(self):
+        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
+        logprob = torch.zeros(2, 3, requires_grad=True)
+        ref_logprob = torch.zeros(2, 3)
+        # exp(logprob - old_logprob) overflows float32; gradients must stay finite.
+        old_logprob = torch.full((2, 3), -100.0)
+        kl = corrected_k3_fn.calculate_kl(logprob, ref_logprob, old_logprob)
+        kl.sum().backward()
+        self.assertTrue(torch.isfinite(logprob.grad).all())
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
index 49c59f2367..54e1639608 100644
--- a/trinity/algorithm/kl_fn/kl_fn.py
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -82,9 +82,18 @@ def calculate_kl_loss(
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
         loss_agg_mode: str,
+        old_logprob: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Dict]:
-        """Compute KL loss."""
-        kl = self.calculate_kl(logprob, ref_logprob)
+        """Compute KL loss.
+
+        Args:
+            logprob: Log probabilities from current policy
+            ref_logprob: Log probabilities from reference policy
+            response_mask: Mask for valid response tokens
+            loss_agg_mode: Loss aggregation mode
+            old_logprob: Log probabilities from old policy (for importance sampling)
+        """
+        kl = self.calculate_kl(logprob, ref_logprob, old_logprob)
         kl_loss = aggregate_loss(kl, response_mask, loss_agg_mode=loss_agg_mode)
         metrics = {
             "kl_loss": kl_loss.detach().item(),
@@ -93,8 +102,19 @@ def calculate_kl_loss(
         return kl_loss * self.kl_coef, metrics
 
     @abstractmethod
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
-        """Compute KL divergence between logprob and ref_logprob."""
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute KL divergence between logprob and ref_logprob.
+
+        Args:
+            logprob: Log probabilities from current policy
+            ref_logprob: Log probabilities from reference policy
+            old_logprob: Log probabilities from old policy (for importance sampling)
+        """
 
     @classmethod
     def default_args(cls):
@@ -108,7 +128,12 @@ class DummyKLFn(KLFn):
     Dummy KL function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return torch.zeros_like(logprob)
 
     def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
@@ -121,6 +146,7 @@ def calculate_kl_loss(
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
         loss_agg_mode: str,
+        old_logprob: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Dict]:
         # return a zero tensor
         return torch.tensor(0.0), {}
@@ -132,7 +158,12 @@ class K1Fn(KLFn):
     KL K1 function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return logprob - ref_logprob
 
 
@@ -142,7 +173,12 @@ class K2Fn(KLFn):
     KL K2 function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return (logprob - ref_logprob).square() * 0.5
 
 
@@ -152,7 +188,12 @@ class K3Fn(KLFn):
     KL K3 function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         logr = ref_logprob - logprob
         return logr.exp() - 1 - logr
 
@@ -163,7 +204,12 @@ class LowVarKLFn(KLFn):
     Low Variance KL function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         kl = ref_logprob - logprob
         kl = torch.clamp(kl, min=-20, max=20)
         ratio = torch.exp(kl)
@@ -177,5 +223,69 @@ class AbsFn(KLFn):
     KL Abs function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return torch.abs(logprob - ref_logprob)
+
+
+@KL_FN.register_module("corrected_k3")
+class CorrectedK3Fn(KLFn):
+    """
+    Corrected K3 function with importance sampling.
+
+    This method applies importance sampling correction to the standard K3 KL divergence.
+    The corrected KL is computed as:
+
+        KL_corrected = (π_θ / π_old) * K3(π_θ || π_ref)
+
+    where:
+        - π_θ: current policy
+        - π_old: old policy (from rollout)
+        - π_ref: reference policy
+        - K3: exp(log(π_ref/π_θ)) - log(π_ref/π_θ) - 1, an estimator of KL(π_θ || π_ref)
+
+    If old_logprob is not provided, it falls back to standard K3.
+    """
+
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute corrected K3 KL divergence with importance sampling.
+
+        Args:
+            logprob: Log probabilities from current policy (log π_θ)
+            ref_logprob: Log probabilities from reference policy (log π_ref)
+            old_logprob: Log probabilities from old policy (log π_old), optional
+
+        Returns:
+            KL divergence tensor with same shape as input
+        """
+        # Standard K3 KL term: exp(log_ratio) - log_ratio - 1
+        # where log_ratio = log(π_ref / π_θ) = ref_logprob - logprob
+        logr = ref_logprob - logprob
+        kl_term = logr.exp() - 1 - logr
+
+        if old_logprob is None:
+            # Fall back to standard K3 if old_logprob is not provided
+            return kl_term
+
+        # Importance sampling ratio π_θ / π_old. Clamp the log-ratio before
+        # exponentiating (same bound LowVarKLFn uses) so exp() cannot overflow
+        # to inf; an inf here turns the backward pass into 0 * inf = NaN even
+        # though the forward value is clamped afterwards.
+        log_ratio_is = torch.clamp(logprob - old_logprob, max=20)
+        ratio_is = log_ratio_is.exp()
+        # Clamp ratio for numerical stability, range [0, 2]
+        ratio_is = torch.clamp(ratio_is, min=0.0, max=2.0)
+
+        # Corrected KL with importance sampling
+        corrected_kl = ratio_is * kl_term
+
+        return corrected_kl
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 0d571ed497..2d1258b9cb 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -162,6 +162,7 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         ref_logprob=model_inputs.get("ref_log_prob", None),
                         response_mask=response_mask,
                         loss_agg_mode=self.loss_agg_mode,
+                        old_logprob=model_inputs.get("old_log_probs", None),
                     )
                     prefix_metrics(
                         src_metrics=kl_loss_metrics,
diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py
index 4e9183609f..48849c9037 100644
--- a/trinity/trainer/verl/megatron_actor.py
+++ b/trinity/trainer/verl/megatron_actor.py
@@ -210,6 +210,7 @@ def loss_func(output, data, meta_info):
                     ref_logprob=data.get("ref_log_prob", None),
                     response_mask=response_mask,
                     loss_agg_mode=self.loss_agg_mode,
+                    old_logprob=data.get("old_log_probs", None),
                 )
                 prefix_metrics(
                     src_metrics=kl_loss_metrics,