@@ -82,9 +82,18 @@ def calculate_kl_loss(
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
         loss_agg_mode: str,
+        old_logprob: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Dict]:
-        """Compute KL loss."""
-        kl = self.calculate_kl(logprob, ref_logprob)
+        """Compute KL loss.
+
+        Args:
+            logprob: Log probabilities from current policy
+            ref_logprob: Log probabilities from reference policy
+            response_mask: Mask for valid response tokens
+            loss_agg_mode: Loss aggregation mode
+            old_logprob: Log probabilities from old policy (for importance sampling)
+        """
+        kl = self.calculate_kl(logprob, ref_logprob, old_logprob)
         kl_loss = aggregate_loss(kl, response_mask, loss_agg_mode=loss_agg_mode)
         metrics = {
             "kl_loss": kl_loss.detach().item(),
@@ -93,8 +102,19 @@ def calculate_kl_loss(
         return kl_loss * self.kl_coef, metrics

     @abstractmethod
-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
-        """Compute KL divergence between logprob and ref_logprob."""
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute KL divergence between logprob and ref_logprob.
+
+        Args:
+            logprob: Log probabilities from current policy
+            ref_logprob: Log probabilities from reference policy
+            old_logprob: Log probabilities from old policy (for importance sampling)
+        """

     @classmethod
     def default_args(cls):
@@ -108,7 +128,12 @@ class DummyKLFn(KLFn):
     Dummy KL function.
     """

-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return torch.zeros_like(logprob)

     def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
@@ -121,6 +146,7 @@ def calculate_kl_loss(
         ref_logprob: torch.Tensor,
         response_mask: torch.Tensor,
         loss_agg_mode: str,
+        old_logprob: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Dict]:
         # return a zero tensor
         return torch.tensor(0.0), {}
@@ -132,7 +158,12 @@ class K1Fn(KLFn):
     KL K1 function.
     """

-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return logprob - ref_logprob


@@ -142,7 +173,12 @@ class K2Fn(KLFn):
     KL K2 function.
     """

-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return (logprob - ref_logprob).square() * 0.5


@@ -152,7 +188,12 @@ class K3Fn(KLFn):
     KL K3 function.
     """

-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         logr = ref_logprob - logprob
         return logr.exp() - 1 - logr

@@ -163,7 +204,12 @@ class LowVarKLFn(KLFn):
     Low Variance KL function.
     """

-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         kl = ref_logprob - logprob
         kl = torch.clamp(kl, min=-20, max=20)
         ratio = torch.exp(kl)
@@ -177,5 +223,66 @@ class AbsFn(KLFn):
     KL Abs function.
     """

-    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return torch.abs(logprob - ref_logprob)
+
+
+@KL_FN.register_module("corrected_k3")
+class CorrectedK3Fn(KLFn):
+    """
+    Corrected K3 function with importance sampling.
+
+    This method applies an importance sampling correction to the standard K3 KL divergence.
+    The corrected KL is computed as:
+
+        KL_corrected = (π_θ / π_old) * KL_standard(π_ref || π_θ)
+
+    where:
+    - π_θ: current policy
+    - π_old: old policy (from rollout)
+    - π_ref: reference policy
+    - KL_standard: exp(log(π_ref/π_θ)) - log(π_ref/π_θ) - 1
+
+    If old_logprob is not provided, this falls back to standard K3.
+    """
+
+    def calculate_kl(
+        self,
+        logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
+        old_logprob: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute the corrected K3 KL divergence with importance sampling.
+
+        Args:
+            logprob: Log probabilities from the current policy (log π_θ)
+            ref_logprob: Log probabilities from the reference policy (log π_ref)
+            old_logprob: Log probabilities from the old policy (log π_old), optional
+
+        Returns:
+            KL divergence tensor with the same shape as the input
+        """
+        # Standard K3 KL term: exp(log_ratio) - log_ratio - 1,
+        # where log_ratio = log(π_ref / π_θ) = ref_logprob - logprob
+        logr = ref_logprob - logprob
+        kl_term = logr.exp() - 1 - logr
+
+        if old_logprob is None:
+            # Fall back to standard K3 when old_logprob is not provided
+            return kl_term
+
+        # Importance sampling ratio: π_θ / π_old
+        log_ratio_is = logprob - old_logprob
+        ratio_is = log_ratio_is.exp()
+        # Clamp the ratio to [0, 2] for numerical stability
+        ratio_is = torch.clamp(ratio_is, min=0.0, max=2.0)
+
+        # Corrected KL with importance sampling
+        corrected_kl = ratio_is * kl_term
+
+        return corrected_kl
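To make the new estimator concrete, here is a minimal sanity-check sketch. It is not part of the diff and the tensor values are illustrative assumptions; it simply reproduces the corrected-K3 arithmetic from CorrectedK3Fn.calculate_kl with plain tensors, covering both the old_logprob=None fallback and the importance-weighted path:

import torch

# Illustrative log-probabilities (hypothetical values, one per response token)
logprob = torch.tensor([-1.0, -2.0, -0.5])      # log pi_theta (current policy)
ref_logprob = torch.tensor([-1.2, -1.8, -0.7])  # log pi_ref (reference policy)
old_logprob = torch.tensor([-1.1, -2.5, -0.5])  # log pi_old (rollout policy)

# Standard K3 term: exp(logr) - 1 - logr >= 0, with logr = ref_logprob - logprob
logr = ref_logprob - logprob
kl_term = logr.exp() - 1 - logr

# Importance sampling ratio pi_theta / pi_old, clamped to [0, 2]
ratio_is = torch.clamp((logprob - old_logprob).exp(), min=0.0, max=2.0)
corrected_kl = ratio_is * kl_term

print(kl_term)       # what the old_logprob=None fallback returns
print(corrected_kl)  # the importance-weighted estimate

Clamping the ratio to [0, 2] bounds how much weight any single token can carry when π_θ has drifted far from π_old, trading a small bias in the corrected estimate for lower variance.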