Commit b9ff286

Add corrected kl with importance sampling (#419)

1 parent 3861859 commit b9ff286
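
In the notation of the CorrectedK3Fn docstring added below (π_θ the current policy, π_old the rollout policy, π_ref the reference policy), the per-token estimator this commit introduces is

$$
\mathrm{KL}_{\mathrm{corrected}}
= \operatorname{clamp}\!\left(\frac{\pi_\theta}{\pi_{\mathrm{old}}},\, 0,\, 2\right)
\cdot \left(\frac{\pi_{\mathrm{ref}}}{\pi_\theta} - \log\frac{\pi_{\mathrm{ref}}}{\pi_\theta} - 1\right),
$$

which reduces to the plain K3 estimator whenever old_logprob is unavailable, since the importance-sampling ratio is then fixed at 1.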

File tree

4 files changed: +250 −10 lines changed

tests/algorithm/kl_fn_test.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""Test for KL functions"""

import unittest

import torch

from trinity.algorithm.kl_fn.kl_fn import KL_FN


class KLFnTest(unittest.TestCase):
    def setUp(self):
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        shape = (4, 10)
        self.logprob = 2 * torch.rand(shape) - 1
        self.ref_logprob = 2 * torch.rand(shape) - 1
        self.old_logprob = 2 * torch.rand(shape) - 1
        self.response_mask = torch.rand(shape) > 0.5

    def test_k1_kl_fn(self):
        kl_fn_cls = KL_FN.get("k1")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = self.logprob - self.ref_logprob
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_k2_kl_fn(self):
        kl_fn_cls = KL_FN.get("k2")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = (self.logprob - self.ref_logprob).square() * 0.5
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_k3_kl_fn(self):
        kl_fn_cls = KL_FN.get("k3")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        logr = self.ref_logprob - self.logprob
        expected_kl = logr.exp() - 1 - logr
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_abs_kl_fn(self):
        kl_fn_cls = KL_FN.get("abs")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = torch.abs(self.logprob - self.ref_logprob)
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_low_var_kl_fn(self):
        kl_fn_cls = KL_FN.get("low_var_kl")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        kl_intermediate = self.ref_logprob - self.logprob
        kl_intermediate = torch.clamp(kl_intermediate, min=-20, max=20)
        ratio = torch.exp(kl_intermediate)
        expected_kl = torch.clamp((ratio - kl_intermediate - 1).contiguous(), min=-10, max=10)
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_dummy_kl_fn(self):
        kl_fn_cls = KL_FN.get("none")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = torch.zeros_like(self.logprob)
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_corrected_k3_fallback(self):
        k3_fn = KL_FN.get("k3")(kl_coef=0.01)
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_standard = k3_fn.calculate_kl(self.logprob, self.ref_logprob)
        kl_corrected_no_old = corrected_k3_fn.calculate_kl(
            self.logprob, self.ref_logprob, old_logprob=None
        )
        self.assertTrue(torch.allclose(kl_standard, kl_corrected_no_old))

    def test_corrected_k3_with_old_logprob(self):
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_corrected = corrected_k3_fn.calculate_kl(
            self.logprob, self.ref_logprob, self.old_logprob
        )
        logr = self.ref_logprob - self.logprob
        kl_standard = logr.exp() - 1 - logr
        log_ratio_is = self.logprob - self.old_logprob
        ratio_is = log_ratio_is.exp()
        ratio_is = torch.clamp(ratio_is, min=0.0, max=2.0)
        expected_kl = ratio_is * kl_standard
        self.assertTrue(torch.allclose(kl_corrected, expected_kl))

    def test_corrected_k3_same_policy(self):
        k3_fn = KL_FN.get("k3")(kl_coef=0.01)
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_standard = k3_fn.calculate_kl(self.logprob, self.ref_logprob)
        kl_corrected = corrected_k3_fn.calculate_kl(self.logprob, self.ref_logprob, self.logprob)
        self.assertTrue(torch.allclose(kl_standard, kl_corrected, rtol=1e-4, atol=1e-6))

    def test_corrected_k3_loss(self):
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_loss, metrics = corrected_k3_fn.calculate_kl_loss(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            loss_agg_mode="token-mean",
            old_logprob=self.old_logprob,
        )
        self.assertEqual(kl_loss.dim(), 0)
        self.assertIn("kl_loss", metrics)
        self.assertIn("kl_coef", metrics)
        self.assertEqual(metrics["kl_coef"], 0.01)

    def test_kl_loss_aggregation_modes(self):
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_loss_mean, _ = corrected_k3_fn.calculate_kl_loss(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            loss_agg_mode="token-mean",
            old_logprob=self.old_logprob,
        )
        kl_loss_sum, _ = corrected_k3_fn.calculate_kl_loss(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            loss_agg_mode="seq-mean-token-sum",
            old_logprob=self.old_logprob,
        )
        self.assertGreater(kl_loss_sum.item(), kl_loss_mean.item())
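
Assuming the repository's standard test layout, the new suite should run with the stock unittest runner, e.g. `python -m unittest tests.algorithm.kl_fn_test`.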

trinity/algorithm/kl_fn/kl_fn.py

Lines changed: 117 additions & 10 deletions
@@ -82,9 +82,18 @@ def calculate_kl_loss(
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
         loss_agg_mode: str,
+        old_logprob: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Dict]:
-        """Compute KL loss."""
-        kl = self.calculate_kl(logprob, ref_logprob)
+        """Compute KL loss.
+
+        Args:
+            logprob: Log probabilities from current policy
+            ref_logprob: Log probabilities from reference policy
+            response_mask: Mask for valid response tokens
+            loss_agg_mode: Loss aggregation mode
+            old_logprob: Log probabilities from old policy (for importance sampling)
+        """
+        kl = self.calculate_kl(logprob, ref_logprob, old_logprob)
         kl_loss = aggregate_loss(kl, response_mask, loss_agg_mode=loss_agg_mode)
         metrics = {
             "kl_loss": kl_loss.detach().item(),
@@ -93,8 +102,19 @@ def calculate_kl_loss(
         return kl_loss * self.kl_coef, metrics
 
     @abstractmethod
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
-        """Compute KL divergence between logprob and ref_logprob."""
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute KL divergence between logprob and ref_logprob.
+
+        Args:
+            logprob: Log probabilities from current policy
+            ref_logprob: Log probabilities from reference policy
+            old_logprob: Log probabilities from old policy (for importance sampling)
+        """
 
     @classmethod
     def default_args(cls):
@@ -108,7 +128,12 @@ class DummyKLFn(KLFn):
     Dummy KL function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return torch.zeros_like(logprob)
 
     def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
@@ -121,6 +146,7 @@ def calculate_kl_loss(
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
         loss_agg_mode: str,
+        old_logprob: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Dict]:
         # return a zero tensor
         return torch.tensor(0.0), {}
@@ -132,7 +158,12 @@ class K1Fn(KLFn):
     KL K1 function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return logprob - ref_logprob
 
 
@@ -142,7 +173,12 @@ class K2Fn(KLFn):
     KL K2 function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return (logprob - ref_logprob).square() * 0.5
 
 
@@ -152,7 +188,12 @@ class K3Fn(KLFn):
     KL K3 function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         logr = ref_logprob - logprob
         return logr.exp() - 1 - logr
 
@@ -163,7 +204,12 @@ class LowVarKLFn(KLFn):
     Low Variance KL function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         kl = ref_logprob - logprob
         kl = torch.clamp(kl, min=-20, max=20)
         ratio = torch.exp(kl)
@@ -177,5 +223,66 @@ class AbsFn(KLFn):
     KL Abs function.
     """
 
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return torch.abs(logprob - ref_logprob)
+
+
+@KL_FN.register_module("corrected_k3")
+class CorrectedK3Fn(KLFn):
+    """
+    Corrected K3 function with importance sampling.
+
+    This method applies importance sampling correction to the standard K3 KL divergence.
+    The corrected KL is computed as:
+
+        KL_corrected = (π_θ / π_old) * KL_standard(π_ref || π_θ)
+
+    where:
+        - π_θ: current policy
+        - π_old: old policy (from rollout)
+        - π_ref: reference policy
+        - KL_standard: exp(log(π_ref/π_θ)) - log(π_ref/π_θ) - 1
+
+    If old_logprob is not provided, it falls back to standard K3.
+    """
+
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute corrected K3 KL divergence with importance sampling.
+
+        Args:
+            logprob: Log probabilities from current policy (log π_θ)
+            ref_logprob: Log probabilities from reference policy (log π_ref)
+            old_logprob: Log probabilities from old policy (log π_old), optional
+
+        Returns:
+            KL divergence tensor with same shape as input
+        """
+        # Standard K3 KL term: exp(log_ratio) - log_ratio - 1
+        # where log_ratio = log(π_ref / π_θ) = ref_logprob - logprob
+        logr = ref_logprob - logprob
+        kl_term = logr.exp() - 1 - logr
+
+        if old_logprob is None:
+            # Fall back to standard K3 if old_logprob is not provided
+            return kl_term
+
+        # Compute importance sampling ratio: π_θ / π_old
+        log_ratio_is = logprob - old_logprob
+        ratio_is = log_ratio_is.exp()
+        # Clamp ratio for numerical stability, range [0, 2]
+        ratio_is = torch.clamp(ratio_is, min=0.0, max=2.0)
+
+        # Corrected KL with importance sampling
+        corrected_kl = ratio_is * kl_term
+
+        return corrected_kl
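
As a quick sanity check, here is a minimal sketch that drives the new class through the same registry API the tests above use (KL_FN.get, calculate_kl, calculate_kl_loss); the tensor shapes and kl_coef=0.01 are illustrative assumptions, not values fixed by the commit.

# Minimal sketch of CorrectedK3Fn usage, mirroring tests/algorithm/kl_fn_test.py.
# Shapes and kl_coef below are illustrative assumptions.
import torch

from trinity.algorithm.kl_fn.kl_fn import KL_FN

logprob = torch.randn(2, 5)      # log π_θ  (current policy)
old_logprob = torch.randn(2, 5)  # log π_old (rollout policy)
ref_logprob = torch.randn(2, 5)  # log π_ref (reference policy)
mask = torch.ones(2, 5, dtype=torch.bool)  # treat every token as a response token

kl_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)

# With old_logprob: K3 term reweighted by clamp(exp(logprob - old_logprob), 0, 2).
kl = kl_fn.calculate_kl(logprob, ref_logprob, old_logprob)

# Without old_logprob: identical to the plain "k3" estimator (the fallback path).
k3_fn = KL_FN.get("k3")(kl_coef=0.01)
assert torch.allclose(
    kl_fn.calculate_kl(logprob, ref_logprob),
    k3_fn.calculate_kl(logprob, ref_logprob),
)

# Full loss path, as the actors below now call it with old_logprob threaded through.
kl_loss, metrics = kl_fn.calculate_kl_loss(
    logprob=logprob,
    ref_logprob=ref_logprob,
    response_mask=mask,
    loss_agg_mode="token-mean",
    old_logprob=old_logprob,
)
print(kl_loss.item(), metrics)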

trinity/trainer/verl/dp_actor.py

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ def update_policy(self, data: DataProto): # noqa: C901
             ref_logprob=model_inputs.get("ref_log_prob", None),
             response_mask=response_mask,
             loss_agg_mode=self.loss_agg_mode,
+            old_logprob=model_inputs.get("old_log_probs", None),
         )
         prefix_metrics(
             src_metrics=kl_loss_metrics,

trinity/trainer/verl/megatron_actor.py

Lines changed: 1 addition & 0 deletions
@@ -210,6 +210,7 @@ def loss_func(output, data, meta_info):
             ref_logprob=data.get("ref_log_prob", None),
             response_mask=response_mask,
             loss_agg_mode=self.loss_agg_mode,
+            old_logprob=data.get("old_log_probs", None),
         )
         prefix_metrics(
             src_metrics=kl_loss_metrics,
