 from ding.hpc_rl import hpc_wrapper
 
 ppo_data = namedtuple(
-    'ppo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
+    'ppo_data',
+    ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained']
 )
 ppo_data_continuous = namedtuple(
-    'ppo_data_continuous',
-    ['mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
+    'ppo_data_continuous', [
+        'mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight',
+        'logit_pretrained'
+    ]
+)
+ppo_policy_data = namedtuple(
+    'ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained']
 )
-ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])
 ppo_policy_data_continuous = namedtuple(
-    'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight']
+    'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight', 'logit_pretrained']
 )
 ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
-ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
-ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
+ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss', 'kl_div'])
+ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss', 'kl_div'])
 ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])
 
 
+def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor:
+    """
+    Overview:
+        Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)],
+        where q is the current policy and p is the pretrained policy.
+        The implementation is based on John Schulman's blog post "Approximating KL Divergence".
+        Reference: http://joschu.net/blog/kl-approx.html
+    Arguments:
+        - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be
+            log(q/p) = logp_new - logp_pretrained.
+        - kl_type (:obj:`str`): The type of KL divergence estimator to use.
+            - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`.
+            - 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`.
+            - 'k3': An unbiased, low-variance estimator: `E_q[(p/q - 1) - log(p/q)]`.
+    Returns:
+        - kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate.
+    """
+    if kl_type == 'k1':
+        return log_ratio.mean()
+    elif kl_type == 'k2':
+        return (log_ratio ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        return (torch.exp(-log_ratio) - 1 + log_ratio).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+
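A quick sanity-check sketch, not part of the patch: it only assumes the new `calculate_kl_div` above and that the patched module lives at `ding/rl_utils/ppo.py`; the synthetic `log_ratio` stands in for `logp_new - logp_pretrained`.

# Compare the three estimators on a synthetic batch of log-ratios log(q/p).
import torch

from ding.rl_utils.ppo import calculate_kl_div

torch.manual_seed(0)
log_ratio = 0.1 * torch.randn(1024)  # stand-in for logp_new - logp_pretrained
for kl_type in ('k1', 'k2', 'k3'):
    print(kl_type, calculate_kl_div(log_ratio, kl_type).item())
# 'k1' is unbiased but noisy, 'k2' is biased but low-variance,
# 'k3' is unbiased and low-variance, matching the docstring above.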
 def shape_fn_ppo(args, kwargs):
     r"""
     Overview:
@@ -46,7 +78,8 @@ def ppo_error(
     data: namedtuple,
     clip_ratio: float = 0.2,
     use_value_clip: bool = True,
-    dual_clip: Optional[float] = None
+    dual_clip: Optional[float] = None,
+    kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -57,6 +90,7 @@ def ppo_error(
         - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which KL divergence estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'.
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -95,20 +129,23 @@ def ppo_error(
     assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
         dual_clip
     )
-    logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data
-    policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight)
-    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip)
+    logit_new, logit_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
+    policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight, logit_pretrained)
+    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)
     value_data = ppo_value_data(value_new, value_old, return_, weight)
     value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)
 
-    return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info
+    return ppo_loss(
+        policy_output.policy_loss, value_loss, policy_output.entropy_loss, policy_output.kl_div
+    ), policy_info
 
 
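For reference, a minimal call sketch of the updated `ppo_error`, not part of the patch: it assumes the patched module at `ding/rl_utils/ppo.py`, a batch of B samples with N discrete actions, and that `logit_pretrained` holds the frozen reference model's logits (pass None to skip the KL term).

import torch

from ding.rl_utils.ppo import ppo_data, ppo_error

B, N = 4, 8
data = ppo_data(
    logit_new=torch.randn(B, N),
    logit_old=torch.randn(B, N),
    action=torch.randint(0, N, (B, )),
    value_new=torch.randn(B),
    value_old=torch.randn(B),
    adv=torch.randn(B),
    return_=torch.randn(B),
    weight=None,
    logit_pretrained=torch.randn(B, N),  # logits of the frozen reference (pretrained) model
)
loss, info = ppo_error(data, clip_ratio=0.2, kl_type='k3')
print(loss.policy_loss, loss.value_loss, loss.entropy_loss, loss.kl_div)
print(info.approx_kl, info.clipfrac)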
 def ppo_policy_error(
     data: namedtuple,
     clip_ratio: float = 0.2,
     dual_clip: Optional[float] = None,
-    entropy_bonus: bool = True
+    entropy_bonus: bool = True,
+    kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -119,6 +156,7 @@ def ppo_policy_error(
         - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
             defaults to 5.0, if you don't want to use it, set this parameter to None
         - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
+        - kl_type (:obj:`str`): which KL divergence estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'.
     Returns:
         - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -148,7 +186,7 @@ def ppo_policy_error(
     .. note::
         For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
     """
-    logit_new, logit_old, action, adv, weight = data
+    logit_new, logit_old, action, adv, weight, logit_pretrained = data
     if weight is None:
         weight = torch.ones_like(adv)
     dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
@@ -180,7 +218,16 @@ def ppo_policy_error(
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    if logit_pretrained is not None:
+        dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained)
+        logp_pretrained = dist_pretrained.log_prob(action)
+        log_ratio = logp_new - logp_pretrained
+        kl_div = calculate_kl_div(log_ratio, kl_type)
+    else:
+        kl_div = 0
+
+    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)
 
 
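A sketch of the policy-only path as it might look in an LLM-style setup, for illustration only: the vocabulary size, the use of `weight` as an action mask, and `entropy_bonus=False` are assumptions based on the docstring notes above, not part of the patch.

import torch

from ding.rl_utils.ppo import ppo_policy_data, ppo_policy_error

B, N = 4, 1000  # batch of tokens over an illustrative vocabulary
data = ppo_policy_data(
    logit_new=torch.randn(B, N),
    logit_old=torch.randn(B, N),
    action=torch.randint(0, N, (B, )),
    adv=torch.randn(B),
    weight=torch.ones(B),                # set to the action mask when padding is present
    logit_pretrained=torch.randn(B, N),  # frozen reference-model logits
)
loss, info = ppo_policy_error(data, clip_ratio=0.2, entropy_bonus=False, kl_type='k1')
print(loss.policy_loss, loss.kl_div, info.approx_kl, info.clipfrac)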
 def ppo_value_error(
@@ -232,7 +279,8 @@ def ppo_error_continuous(
     data: namedtuple,
     clip_ratio: float = 0.2,
     use_value_clip: bool = True,
-    dual_clip: Optional[float] = None
+    dual_clip: Optional[float] = None,
+    kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -243,6 +291,7 @@ def ppo_error_continuous(
         - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which KL divergence estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'.
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -281,7 +330,7 @@ def ppo_error_continuous(
     assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
         dual_clip
     )
-    mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight = data
+    mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
     if weight is None:
         weight = torch.ones_like(adv)
 
@@ -314,12 +363,23 @@ def ppo_error_continuous(
     else:
         value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
 
-    return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+    if logit_pretrained is not None:
+        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
+        logp_pretrained = dist_pretrained.log_prob(action)
+        log_ratio = logp_new - logp_pretrained
+        kl_div = calculate_kl_div(log_ratio, kl_type)
+    else:
+        kl_div = 0
+
+    return ppo_loss(policy_loss, value_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)
 
 
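The continuous-action counterpart follows the same pattern. A hedged sketch, not part of the patch, assuming mu/sigma dicts of shape (B, action_dim), which is how `logit_pretrained['mu']` / `['sigma']` are consumed above.

import torch

from ding.rl_utils.ppo import ppo_data_continuous, ppo_error_continuous

B, D = 4, 6

def mu_sigma():
    # Positive sigma is required by the underlying Normal distribution.
    return {'mu': torch.randn(B, D), 'sigma': torch.rand(B, D) + 0.1}

data = ppo_data_continuous(
    mu_sigma_new=mu_sigma(),
    mu_sigma_old=mu_sigma(),
    action=torch.randn(B, D),
    value_new=torch.randn(B),
    value_old=torch.randn(B),
    adv=torch.randn(B),
    return_=torch.randn(B),
    weight=None,
    logit_pretrained=mu_sigma(),  # despite the name, a mu/sigma dict for the reference policy
)
loss, info = ppo_error_continuous(data, kl_type='k2')
print(loss.kl_div)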
-def ppo_policy_error_continuous(data: namedtuple,
-                                clip_ratio: float = 0.2,
-                                dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
+def ppo_policy_error_continuous(
+    data: namedtuple,
+    clip_ratio: float = 0.2,
+    dual_clip: Optional[float] = None,
+    kl_type: str = 'k1'
+) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
         Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
@@ -328,6 +388,7 @@ def ppo_policy_error_continuous(data: namedtuple,
         - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which KL divergence estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'.
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -353,7 +414,7 @@ def ppo_policy_error_continuous(data: namedtuple,
     assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
         dual_clip
     )
-    mu_sigma_new, mu_sigma_old, action, adv, weight = data
+    mu_sigma_new, mu_sigma_old, action, adv, weight, logit_pretrained = data
     if weight is None:
         weight = torch.ones_like(adv)
 
@@ -377,4 +438,13 @@ def ppo_policy_error_continuous(data: namedtuple,
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    if logit_pretrained is not None:
+        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
+        logp_pretrained = dist_pretrained.log_prob(action)
+        log_ratio = logp_new - logp_pretrained
+        kl_div = calculate_kl_div(log_ratio, kl_type)
+    else:
+        kl_div = 0
+
+    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)
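Finally, a sketch of how a caller might fold the new `kl_div` term into a training loss for the continuous policy-only variant; the penalty coefficient `kl_beta` is an assumption of this sketch, not something the patch defines.

import torch

from ding.rl_utils.ppo import ppo_policy_data_continuous, ppo_policy_error_continuous

B, D = 4, 6
data = ppo_policy_data_continuous(
    mu_sigma_new={'mu': torch.randn(B, D, requires_grad=True), 'sigma': torch.rand(B, D) + 0.1},
    mu_sigma_old={'mu': torch.randn(B, D), 'sigma': torch.rand(B, D) + 0.1},
    action=torch.randn(B, D),
    adv=torch.randn(B),
    weight=None,
    logit_pretrained={'mu': torch.randn(B, D), 'sigma': torch.rand(B, D) + 0.1},
)
loss, info = ppo_policy_error_continuous(data, kl_type='k3')
kl_beta = 0.1  # assumed KL penalty coefficient
total_loss = loss.policy_loss + kl_beta * loss.kl_div
total_loss.backward()  # gradients flow through mu_sigma_new only; the reference policy stays fixed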