Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions tests/algorithm/kl_fn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""Test for KL functions"""

import unittest

import torch

from trinity.algorithm.kl_fn.kl_fn import KL_FN


class KLFnTest(unittest.TestCase):
    """Tests for the KL divergence functions registered under KL_FN."""

    def setUp(self):
        """Seed every RNG source and prepare random log-prob fixtures."""
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        dims = (4, 10)
        # Pseudo log-probabilities drawn uniformly from [-1, 1).
        self.logprob = 2 * torch.rand(dims) - 1
        self.ref_logprob = 2 * torch.rand(dims) - 1
        self.old_logprob = 2 * torch.rand(dims) - 1
        # Boolean mask marking roughly half of the tokens as response tokens.
        self.response_mask = torch.rand(dims) > 0.5

    def _build(self, name):
        """Instantiate the KL function registered as ``name`` with kl_coef=0.01."""
        return KL_FN.get(name)(kl_coef=0.01)

    def test_k1_kl_fn(self):
        """k1 is the plain log-prob difference."""
        kl = self._build("k1").calculate_kl(self.logprob, self.ref_logprob)
        self.assertTrue(torch.allclose(kl, self.logprob - self.ref_logprob))

    def test_k2_kl_fn(self):
        """k2 is half the squared log-prob difference."""
        kl = self._build("k2").calculate_kl(self.logprob, self.ref_logprob)
        self.assertTrue(torch.allclose(kl, (self.logprob - self.ref_logprob).square() * 0.5))

    def test_k3_kl_fn(self):
        """k3 is exp(logr) - 1 - logr with logr = ref - current."""
        kl = self._build("k3").calculate_kl(self.logprob, self.ref_logprob)
        logr = self.ref_logprob - self.logprob
        self.assertTrue(torch.allclose(kl, logr.exp() - 1 - logr))

    def test_abs_kl_fn(self):
        """abs is the absolute log-prob difference."""
        kl = self._build("abs").calculate_kl(self.logprob, self.ref_logprob)
        self.assertTrue(torch.allclose(kl, torch.abs(self.logprob - self.ref_logprob)))

    def test_low_var_kl_fn(self):
        """low_var_kl clamps both the log-ratio and the resulting KL term."""
        kl = self._build("low_var_kl").calculate_kl(self.logprob, self.ref_logprob)
        clipped = torch.clamp(self.ref_logprob - self.logprob, min=-20, max=20)
        ratio = torch.exp(clipped)
        expected = torch.clamp((ratio - clipped - 1).contiguous(), min=-10, max=10)
        self.assertTrue(torch.allclose(kl, expected))

    def test_dummy_kl_fn(self):
        """The 'none' function always yields a zero KL tensor."""
        kl = self._build("none").calculate_kl(self.logprob, self.ref_logprob)
        self.assertTrue(torch.allclose(kl, torch.zeros_like(self.logprob)))

    def test_corrected_k3_fallback(self):
        """Without old_logprob, corrected_k3 must match plain k3."""
        baseline = self._build("k3").calculate_kl(self.logprob, self.ref_logprob)
        fallback = self._build("corrected_k3").calculate_kl(
            self.logprob, self.ref_logprob, old_logprob=None
        )
        self.assertTrue(torch.allclose(baseline, fallback))

    def test_corrected_k3_with_old_logprob(self):
        """With old_logprob, corrected_k3 scales k3 by a clamped IS ratio."""
        kl = self._build("corrected_k3").calculate_kl(
            self.logprob, self.ref_logprob, self.old_logprob
        )
        logr = self.ref_logprob - self.logprob
        base_kl = logr.exp() - 1 - logr
        is_ratio = torch.clamp((self.logprob - self.old_logprob).exp(), min=0.0, max=2.0)
        self.assertTrue(torch.allclose(kl, is_ratio * base_kl))

    def test_corrected_k3_same_policy(self):
        """When old == current policy the IS ratio is ~1, recovering k3."""
        baseline = self._build("k3").calculate_kl(self.logprob, self.ref_logprob)
        corrected = self._build("corrected_k3").calculate_kl(
            self.logprob, self.ref_logprob, self.logprob
        )
        self.assertTrue(torch.allclose(baseline, corrected, rtol=1e-4, atol=1e-6))

    def test_corrected_k3_loss(self):
        """calculate_kl_loss returns a scalar loss plus metric entries."""
        loss, metrics = self._build("corrected_k3").calculate_kl_loss(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            loss_agg_mode="token-mean",
            old_logprob=self.old_logprob,
        )
        self.assertEqual(loss.dim(), 0)
        self.assertIn("kl_loss", metrics)
        self.assertIn("kl_coef", metrics)
        self.assertEqual(metrics["kl_coef"], 0.01)

    def test_kl_loss_aggregation_modes(self):
        """Per-sequence token sums aggregate to a larger loss than token means."""
        fn = self._build("corrected_k3")
        common = dict(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            old_logprob=self.old_logprob,
        )
        loss_mean, _ = fn.calculate_kl_loss(loss_agg_mode="token-mean", **common)
        loss_sum, _ = fn.calculate_kl_loss(loss_agg_mode="seq-mean-token-sum", **common)
        self.assertGreater(loss_sum.item(), loss_mean.item())
127 changes: 117 additions & 10 deletions trinity/algorithm/kl_fn/kl_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,18 @@ def calculate_kl_loss(
ref_logprob: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str,
old_logprob: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict]:
"""Compute KL loss."""
kl = self.calculate_kl(logprob, ref_logprob)
"""Compute KL loss.

Args:
logprob: Log probabilities from current policy
ref_logprob: Log probabilities from reference policy
response_mask: Mask for valid response tokens
loss_agg_mode: Loss aggregation mode
old_logprob: Log probabilities from old policy (for importance sampling)
"""
kl = self.calculate_kl(logprob, ref_logprob, old_logprob)
kl_loss = aggregate_loss(kl, response_mask, loss_agg_mode=loss_agg_mode)
metrics = {
"kl_loss": kl_loss.detach().item(),
Expand All @@ -93,8 +102,19 @@ def calculate_kl_loss(
return kl_loss * self.kl_coef, metrics

@abstractmethod
def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
"""Compute KL divergence between logprob and ref_logprob."""
def calculate_kl(
self,
logprob: torch.Tensor,
ref_logprob: torch.Tensor,
old_logprob: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute KL divergence between logprob and ref_logprob.

Args:
logprob: Log probabilities from current policy
ref_logprob: Log probabilities from reference policy
old_logprob: Log probabilities from old policy (for importance sampling)
"""

@classmethod
def default_args(cls):
Expand All @@ -108,7 +128,12 @@ class DummyKLFn(KLFn):
Dummy KL function.
"""

def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
def calculate_kl(
    self,
    logprob: torch.Tensor,
    ref_logprob: torch.Tensor,
    old_logprob: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return an all-zero tensor: the dummy variant applies no KL penalty.

    ``ref_logprob`` and ``old_logprob`` are accepted only for interface
    compatibility and are ignored.
    """
    # Zeros with the same shape, dtype and device as the input.
    return logprob.new_zeros(logprob.size())

def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
Expand All @@ -121,6 +146,7 @@ def calculate_kl_loss(
ref_logprob: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str,
old_logprob: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict]:
# return a zero tensor
return torch.tensor(0.0), {}
Expand All @@ -132,7 +158,12 @@ class K1Fn(KLFn):
KL K1 function.
"""

def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
def calculate_kl(
    self,
    logprob: torch.Tensor,
    ref_logprob: torch.Tensor,
    old_logprob: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the K1 estimator: the raw log-prob difference.

    ``old_logprob`` is accepted only for interface compatibility and ignored.
    """
    difference = logprob - ref_logprob
    return difference


Expand All @@ -142,7 +173,12 @@ class K2Fn(KLFn):
KL K2 function.
"""

def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
def calculate_kl(
    self,
    logprob: torch.Tensor,
    ref_logprob: torch.Tensor,
    old_logprob: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the K2 estimator: half the squared log-prob difference.

    ``old_logprob`` is accepted only for interface compatibility and ignored.
    """
    delta = logprob - ref_logprob
    return delta.square() * 0.5


Expand All @@ -152,7 +188,12 @@ class K3Fn(KLFn):
KL K3 function.
"""

def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
def calculate_kl(
    self,
    logprob: torch.Tensor,
    ref_logprob: torch.Tensor,
    old_logprob: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the K3 estimator: exp(r) - 1 - r with r = ref_logprob - logprob.

    ``old_logprob`` is accepted only for interface compatibility and ignored.
    """
    log_ratio = ref_logprob - logprob
    return log_ratio.exp() - 1 - log_ratio

Expand All @@ -163,7 +204,12 @@ class LowVarKLFn(KLFn):
Low Variance KL function.
"""

def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
def calculate_kl(
self,
logprob: torch.Tensor,
ref_logprob: torch.Tensor,
old_logprob: Optional[torch.Tensor] = None,
) -> torch.Tensor:
kl = ref_logprob - logprob
kl = torch.clamp(kl, min=-20, max=20)
ratio = torch.exp(kl)
Expand All @@ -177,5 +223,66 @@ class AbsFn(KLFn):
KL Abs function.
"""

def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
def calculate_kl(
    self,
    logprob: torch.Tensor,
    ref_logprob: torch.Tensor,
    old_logprob: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the absolute value of the log-prob difference.

    ``old_logprob`` is accepted only for interface compatibility and ignored.
    """
    return (logprob - ref_logprob).abs()


@KL_FN.register_module("corrected_k3")
class CorrectedK3Fn(KLFn):
    """K3 KL divergence with an importance-sampling (IS) correction.

    The corrected estimate reweights the standard K3 term by the clipped
    ratio between the current and the rollout (old) policy:

        KL_corrected = clamp(pi_theta / pi_old, 0, 2) * K3(pi_ref || pi_theta)

    where K3 = exp(r) - 1 - r and r = log(pi_ref) - log(pi_theta).  When no
    old-policy log-probs are supplied, the plain K3 estimate is returned.
    """

    def calculate_kl(
        self,
        logprob: torch.Tensor,
        ref_logprob: torch.Tensor,
        old_logprob: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the (optionally IS-corrected) K3 KL divergence.

        Args:
            logprob: Current-policy log-probabilities (log pi_theta).
            ref_logprob: Reference-policy log-probabilities (log pi_ref).
            old_logprob: Rollout-policy log-probabilities (log pi_old);
                when ``None`` the correction is skipped.

        Returns:
            Per-token KL values, same shape as ``logprob``.
        """
        # Standard K3 term built from the ref/current log-ratio.
        log_ratio = ref_logprob - logprob
        base_kl = log_ratio.exp() - 1 - log_ratio

        # No rollout log-probs available: fall back to plain K3.
        if old_logprob is None:
            return base_kl

        # Importance weight pi_theta / pi_old, clipped to [0, 2] for
        # numerical stability.
        # NOTE(review): the weight is not detached, so gradients flow through
        # it as well as through base_kl — presumably intentional; confirm.
        weight = torch.clamp((logprob - old_logprob).exp(), min=0.0, max=2.0)

        return weight * base_kl
1 change: 1 addition & 0 deletions trinity/trainer/verl/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def update_policy(self, data: DataProto): # noqa: C901
ref_logprob=model_inputs.get("ref_log_prob", None),
response_mask=response_mask,
loss_agg_mode=self.loss_agg_mode,
old_logprob=model_inputs.get("old_log_probs", None),
)
prefix_metrics(
src_metrics=kl_loss_metrics,
Expand Down
1 change: 1 addition & 0 deletions trinity/trainer/verl/megatron_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def loss_func(output, data, meta_info):
ref_logprob=data.get("ref_log_prob", None),
response_mask=response_mask,
loss_agg_mode=self.loss_agg_mode,
old_logprob=data.get("old_log_probs", None),
)
prefix_metrics(
src_metrics=kl_loss_metrics,
Expand Down