 ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
 ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
 ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
-ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])
+ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac', 'kl_div'])
 
 
 def shape_fn_ppo(args, kwargs):
@@ -46,7 +46,8 @@ def ppo_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
         use_value_clip: bool = True,
-        dual_clip: Optional[float] = None
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -57,6 +58,7 @@ def ppo_error(
         - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which KL divergence estimator to use, one of 'k1', 'k2', 'k3', defaults to 'k1'
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -97,7 +99,7 @@ def ppo_error(
     )
     logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data
     policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight)
-    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip)
+    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)
     value_data = ppo_value_data(value_new, value_old, return_, weight)
     value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)
 
@@ -108,7 +110,8 @@ def ppo_policy_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
         dual_clip: Optional[float] = None,
-        entropy_bonus: bool = True
+        entropy_bonus: bool = True,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -119,6 +122,7 @@ def ppo_policy_error(
         - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
             defaults to 5.0, if you don't want to use it, set this parameter to None
         - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
+        - kl_type (:obj:`str`): which KL divergence estimator to use, one of 'k1', 'k2', 'k3', defaults to 'k1'
     Returns:
         - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -180,7 +184,18 @@ def ppo_policy_error(
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    logr = logp_old - logp_new
+    if kl_type == 'k1':
+        kl_div = logr.mean()
+    elif kl_type == 'k2':
+        kl_div = (logr ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        kl_div = (torch.exp(-logr) - 1 + logr).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)
 
 
 def ppo_value_error(
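
For context, the three branches added above are the k1, k2 and k3 Monte-Carlo estimators of KL(pi_old || pi_new) from John Schulman's "Approximating KL Divergence" note: k1 is unbiased but high-variance and can be negative per sample, k2 is biased but low-variance, and k3 is unbiased, non-negative per sample, and usually lower-variance than k1. A minimal standalone sketch of the same computation follows; the helper name kl_divergence_estimate is illustrative and not part of this diff.

import torch


def kl_divergence_estimate(logp_old: torch.Tensor, logp_new: torch.Tensor, kl_type: str = 'k1') -> torch.Tensor:
    # logr = log(pi_old(a|s) / pi_new(a|s)), with actions sampled from pi_old
    logr = logp_old - logp_new
    if kl_type == 'k1':
        # plain sample mean of the log-ratio: unbiased, high variance
        return logr.mean()
    elif kl_type == 'k2':
        # squared log-ratio over two: biased, low variance, always non-negative
        return (logr ** 2 / 2).mean()
    elif kl_type == 'k3':
        # exp(-logr) - 1 + logr: unbiased, non-negative, usually lower variance than k1
        return (torch.exp(-logr) - 1 + logr).mean()
    raise ValueError(f"Unknown kl_type: {kl_type}")
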
@@ -232,7 +247,8 @@ def ppo_error_continuous(
         data: namedtuple,
         clip_ratio: float = 0.2,
         use_value_clip: bool = True,
-        dual_clip: Optional[float] = None
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -243,6 +259,7 @@ def ppo_error_continuous(
         - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which KL divergence estimator to use, one of 'k1', 'k2', 'k3', defaults to 'k1'
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -314,12 +331,25 @@ def ppo_error_continuous(
     else:
         value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
 
-    return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+    logr = logp_old - logp_new
+    if kl_type == 'k1':
+        kl_div = logr.mean()
+    elif kl_type == 'k2':
+        kl_div = (logr ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        kl_div = (torch.exp(-logr) - 1 + logr).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+    return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)
 
 
-def ppo_policy_error_continuous(data: namedtuple,
-                                clip_ratio: float = 0.2,
-                                dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
+def ppo_policy_error_continuous(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
+) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
         Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
@@ -328,6 +358,7 @@ def ppo_policy_error_continuous(data: namedtuple,
         - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which KL divergence estimator to use, one of 'k1', 'k2', 'k3', defaults to 'k1'
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -377,4 +408,15 @@ def ppo_policy_error_continuous(data: namedtuple,
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    logr = logp_old - logp_new
+    if kl_type == 'k1':
+        kl_div = logr.mean()
+    elif kl_type == 'k2':
+        kl_div = (logr ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        kl_div = (torch.exp(-logr) - 1 + logr).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)
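
A hedged usage sketch of the extended interface follows, assuming the ding.rl_utils import path and the discrete-action ppo_policy_data layout (logit_new, logit_old, action, adv, weight) shown earlier in the diff; batch size, action count and the entropy coefficient are arbitrary.

import torch
from ding.rl_utils import ppo_policy_data, ppo_policy_error  # import path assumed

B, N = 4, 6  # batch size and number of discrete actions (arbitrary)
logit_new = torch.randn(B, N, requires_grad=True)
logit_old = logit_new.detach() + 0.1 * torch.randn(B, N)
action = torch.randint(0, N, (B, ))
adv = torch.randn(B)
data = ppo_policy_data(logit_new, logit_old, action, adv, weight=None)

loss, info = ppo_policy_error(data, clip_ratio=0.2, dual_clip=None, kl_type='k3')
total_loss = loss.policy_loss - 0.01 * loss.entropy_loss  # entropy coefficient is arbitrary
total_loss.backward()

# approx_kl and clipfrac are Python floats; kl_div is returned as a 0-dim tensor in this diff
print(info.approx_kl, info.clipfrac, info.kl_div.item())
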