
Commit 486bb30

feature(wyx): add three KL-divergence variants (#870)
* feature(wyx): add three KL-divergence variants
* fix bugs and add description for KL-divergence variants
* add a period for each comment line
* fix bugs and add KL-divergence parameter descriptions
* fix flake8 E501 error in ppo.py docstring
* fix trailing whitespace in ppo.py
1 parent f6ee768 commit 486bb30
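
For context, the three estimators added by this commit come from John Schulman's "Approximating KL Divergence" post referenced in the diff: k1 = log(q/p) is unbiased but high-variance, k2 = 1/2 * (log(q/p))^2 is biased but low-variance, and k3 = (p/q - 1) - log(p/q) is unbiased with low variance. The standalone sketch below is not part of the commit; the toy logits, sample count, and variable names are illustrative. It checks the three estimators against the analytic KL between two small categorical policies, mirroring the blog post's experiment.

# Standalone sketch (not part of the commit): sanity-check the three KL
# estimators referenced above on two toy categorical policies.
# q plays the role of the current policy, p the pretrained policy.
import torch

torch.manual_seed(0)
q = torch.distributions.Categorical(logits=torch.randn(6))
p = torch.distributions.Categorical(logits=torch.randn(6))

# Draw Monte-Carlo samples from the current policy q.
actions = q.sample((100000, ))
log_ratio = q.log_prob(actions) - p.log_prob(actions)  # log(q/p), as in calculate_kl_div

k1 = log_ratio.mean()                                   # unbiased, high variance
k2 = (log_ratio ** 2 / 2).mean()                        # biased, low variance
k3 = (torch.exp(-log_ratio) - 1 + log_ratio).mean()     # unbiased, low variance

exact = torch.distributions.kl_divergence(q, p)         # analytic KL(q, p)
print("exact={:.4f} k1={:.4f} k2={:.4f} k3={:.4f}".format(exact.item(), k1.item(), k2.item(), k3.item()))

On a single batch the k1 estimate can even come out negative although the true KL is non-negative, which is the motivation for also exposing k2 and k3.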

11 files changed, +149 -38 lines changed

ding/policy/ppo.py

Lines changed: 43 additions & 8 deletions

@@ -76,6 +76,15 @@ class PPOPolicy(Policy):
             grad_clip_value=0.5,
             # (bool) Whether ignore done (usually for max step termination env).
             ignore_done=False,
+            # (str) The type of KL divergence loss between current policy and pretrained policy, ['k1', 'k2', 'k3'].
+            # Reference: http://joschu.net/blog/kl-approx.html
+            kl_type='k1',
+            # (float) The weight of KL divergence loss.
+            kl_beta=0.0,
+            # (Optional[str]) The path of pretrained model checkpoint.
+            # If provided, KL regularizer will be calculated between current policy and pretrained policy.
+            # Default to None, which means KL is not calculated.
+            pretrained_model_path=None,
         ),
         # collect_mode config
         collect=dict(
@@ -186,12 +195,23 @@ def _init_learn(self) -> None:

         self._learn_model = model_wrap(self._model, wrapper_name='base')

+        # load pretrained model
+        if self._cfg.learn.pretrained_model_path is not None:
+            self._pretrained_model = copy.deepcopy(self._model)
+            state_dict = torch.load(self._cfg.learn.pretrained_model_path, map_location='cpu')
+            self._pretrained_model.load_state_dict(state_dict)
+            self._pretrained_model.eval()
+        else:
+            self._pretrained_model = None
+
         # Algorithm config
         self._value_weight = self._cfg.learn.value_weight
         self._entropy_weight = self._cfg.learn.entropy_weight
         self._clip_ratio = self._cfg.learn.clip_ratio
         self._adv_norm = self._cfg.learn.adv_norm
         self._value_norm = self._cfg.learn.value_norm
+        self._kl_type = self._cfg.learn.kl_type
+        self._kl_beta = self._cfg.learn.kl_beta
         if self._value_norm:
             self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
         self._gamma = self._cfg.collect.discount_factor
@@ -285,45 +305,57 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     # Normalize advantage in a train_batch
                     adv = (adv - adv.mean()) / (adv.std() + 1e-8)

+                if self._pretrained_model is not None:
+                    with torch.no_grad():
+                        logit_pretrained = self._pretrained_model.forward(batch['obs'], mode='compute_actor')['logit']
+                else:
+                    logit_pretrained = None
+
                 # Calculate ppo error
                 if self._action_space == 'continuous':
                     ppo_batch = ppo_data(
                         output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
-                        batch['return'], batch['weight']
+                        batch['return'], batch['weight'], logit_pretrained
                     )
-                    ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio)
+                    ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
                 elif self._action_space == 'discrete':
                     ppo_batch = ppo_data(
                         output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
-                        batch['return'], batch['weight']
+                        batch['return'], batch['weight'], logit_pretrained
                     )
-                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
+                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
                 elif self._action_space == 'hybrid':
                     # discrete part (discrete policy loss and entropy loss)
                     ppo_discrete_batch = ppo_policy_data(
                         output['logit']['action_type'], batch['logit']['action_type'], batch['action']['action_type'],
                         adv, batch['weight']
                     )
-                    ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio)
+                    ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(
+                        ppo_discrete_batch, self._clip_ratio, kl_type=self._kl_type
+                    )
                     # continuous part (continuous policy loss and entropy loss, value loss)
                     ppo_continuous_batch = ppo_data(
                         output['logit']['action_args'], batch['logit']['action_args'], batch['action']['action_args'],
                         output['value'], batch['value'], adv, batch['return'], batch['weight']
                     )
                     ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
-                        ppo_continuous_batch, self._clip_ratio
+                        ppo_continuous_batch, self._clip_ratio, kl_type=self._kl_type
                     )
                     # sum discrete and continuous loss
                     ppo_loss = type(ppo_continuous_loss)(
                         ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss,
-                        ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
+                        ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss, ppo_continuous_loss.kl_div
                     )
                     ppo_info = type(ppo_continuous_info)(
                         max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
                         max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
                     )
                 wv, we = self._value_weight, self._entropy_weight
-                total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+                kl_div = ppo_loss.kl_div
+                total_loss = (
+                    ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss +
+                    self._kl_beta * kl_div
+                )

                 self._optimizer.zero_grad()
                 total_loss.backward()
@@ -346,6 +378,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     'value_max': output['value'].max().item(),
                     'approx_kl': ppo_info.approx_kl,
                     'clipfrac': ppo_info.clipfrac,
+                    'kl_div': kl_div.item(),
                 }
                 if self._action_space == 'continuous':
                     return_info.update(
@@ -594,6 +627,8 @@ def _monitor_vars_learn(self) -> List[str]:
             'value_max',
             'value_mean',
         ]
+        if self._pretrained_model is not None:
+            variables += ['kl_div']
         if self._action_space == 'continuous':
             variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act']
         return variables
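
On the user side, turning the regularizer on only requires the three new learn keys plus a checkpoint to regularize against. Below is a minimal, hypothetical config fragment; the checkpoint path and the kl_beta value are placeholders, not values from this commit.

# Hypothetical fragment of a user config (placeholder path and weight).
# With pretrained_model_path set, _init_learn loads a frozen copy of the model and
# _forward_learn adds kl_beta * kl_div to the usual PPO loss, i.e.
# total_loss = policy_loss + wv * value_loss - we * entropy_loss + kl_beta * kl_div.
from easydict import EasyDict

ppo_kl_config = EasyDict(
    policy=dict(
        learn=dict(
            kl_type='k1',  # one of ['k1', 'k2', 'k3']
            kl_beta=0.01,  # weight of the KL regularizer (placeholder)
            pretrained_model_path='./exp/pretrained_ppo/ckpt/ckpt_best.pth.tar',  # placeholder path
        ),
    ),
)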

ding/rl_utils/ppo.py

Lines changed: 93 additions & 23 deletions

@@ -6,22 +6,54 @@
 from ding.hpc_rl import hpc_wrapper

 ppo_data = namedtuple(
-    'ppo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
+    'ppo_data',
+    ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained']
 )
 ppo_data_continuous = namedtuple(
-    'ppo_data_continuous',
-    ['mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
+    'ppo_data_continuous', [
+        'mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight',
+        'logit_pretrained'
+    ]
+)
+ppo_policy_data = namedtuple(
+    'ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained']
 )
-ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])
 ppo_policy_data_continuous = namedtuple(
-    'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight']
+    'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight', 'logit_pretrained']
 )
 ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
-ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
-ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
+ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss', 'kl_div'])
+ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss', 'kl_div'])
 ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])


+def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor:
+    """
+    Overview:
+        Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)],
+        where q is the current policy and p is the pretrained policy.
+        The implementation is based on John Schulman's blog post "Approximating KL Divergence".
+        Reference: http://joschu.net/blog/kl-approx.html
+    Arguments:
+        - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be
+          log(q/p) = logp_new - logp_pretrained.
+        - kl_type (:obj:`str`): The type of KL divergence estimator to use.
+            - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`.
+            - 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`.
+            - 'k3': An unbiased, low-variance estimator: `E_q[(p/q - 1) - log(p/q)]`.
+    Returns:
+        - kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate.
+    """
+    if kl_type == 'k1':
+        return log_ratio.mean()
+    elif kl_type == 'k2':
+        return (log_ratio ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        return (torch.exp(-log_ratio) - 1 + log_ratio).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+
 def shape_fn_ppo(args, kwargs):
     r"""
     Overview:
@@ -46,7 +78,8 @@ def ppo_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
         use_value_clip: bool = True,
-        dual_clip: Optional[float] = None
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -57,6 +90,7 @@ def ppo_error(
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
            defaults to 5.0, if you don't want to use it, set this parameter to None
+       - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
     Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -95,20 +129,23 @@ def ppo_error(
     assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
         dual_clip
     )
-    logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data
-    policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight)
-    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip)
+    logit_new, logit_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
+    policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight, logit_pretrained)
+    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)
     value_data = ppo_value_data(value_new, value_old, return_, weight)
     value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)

-    return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info
+    return ppo_loss(
+        policy_output.policy_loss, value_loss, policy_output.entropy_loss, policy_output.kl_div
+    ), policy_info


 def ppo_policy_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
         dual_clip: Optional[float] = None,
-        entropy_bonus: bool = True
+        entropy_bonus: bool = True,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -119,6 +156,7 @@ def ppo_policy_error(
        - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf), \
            defaults to 5.0, if you don't want to use it, set this parameter to None
        - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
+       - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
     Returns:
        - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -148,7 +186,7 @@ def ppo_policy_error(
     .. note::
        For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
     """
-    logit_new, logit_old, action, adv, weight = data
+    logit_new, logit_old, action, adv, weight, logit_pretrained = data
     if weight is None:
         weight = torch.ones_like(adv)
     dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
@@ -180,7 +218,16 @@ def ppo_policy_error(
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    if logit_pretrained is not None:
+        dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained)
+        logp_pretrained = dist_pretrained.log_prob(action)
+        log_ratio = logp_new - logp_pretrained
+        kl_div = calculate_kl_div(log_ratio, kl_type)
+    else:
+        kl_div = 0
+
+    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)


 def ppo_value_error(
@@ -232,7 +279,8 @@ def ppo_error_continuous(
         data: namedtuple,
         clip_ratio: float = 0.2,
         use_value_clip: bool = True,
-        dual_clip: Optional[float] = None
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -243,6 +291,7 @@ def ppo_error_continuous(
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
            defaults to 5.0, if you don't want to use it, set this parameter to None
+       - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
     Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -281,7 +330,7 @@ def ppo_error_continuous(
     assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
         dual_clip
     )
-    mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight = data
+    mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
     if weight is None:
         weight = torch.ones_like(adv)

@@ -314,12 +363,23 @@ def ppo_error_continuous(
     else:
         value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

-    return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+    if logit_pretrained is not None:
+        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
+        logp_pretrained = dist_pretrained.log_prob(action)
+        log_ratio = logp_new - logp_pretrained
+        kl_div = calculate_kl_div(log_ratio, kl_type)
+    else:
+        kl_div = 0
+
+    return ppo_loss(policy_loss, value_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)


-def ppo_policy_error_continuous(data: namedtuple,
-                                clip_ratio: float = 0.2,
-                                dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
+def ppo_policy_error_continuous(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
+) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
         Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
@@ -328,6 +388,7 @@ def ppo_policy_error_continuous(data: namedtuple,
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
       - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
            defaults to 5.0, if you don't want to use it, set this parameter to None
+       - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
     Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -353,7 +414,7 @@ def ppo_policy_error_continuous(data: namedtuple,
     assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
         dual_clip
     )
-    mu_sigma_new, mu_sigma_old, action, adv, weight = data
+    mu_sigma_new, mu_sigma_old, action, adv, weight, logit_pretrained = data
     if weight is None:
         weight = torch.ones_like(adv)

@@ -377,4 +438,13 @@ def ppo_policy_error_continuous(data: namedtuple,
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    if logit_pretrained is not None:
+        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
+        logp_pretrained = dist_pretrained.log_prob(action)
+        log_ratio = logp_new - logp_pretrained
+        kl_div = calculate_kl_div(log_ratio, kl_type)
+    else:
+        kl_div = 0
+
+    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)
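
For the discrete case the updated API can be exercised directly. The sketch below is illustrative rather than taken from the commit: the tensor shapes, loss coefficients, and the choice of 'k3' are placeholders, and it assumes DI-engine is installed. It builds a ppo_data tuple with the new logit_pretrained field and reads the extra kl_div entry off the returned loss.

# Illustrative call of the updated ppo_error with a pretrained-policy logit.
import torch
from ding.rl_utils.ppo import ppo_data, ppo_error

B, N = 8, 4  # batch size, number of discrete actions
logit_new = torch.randn(B, N, requires_grad=True)
logit_old = torch.randn(B, N)
logit_pretrained = torch.randn(B, N)  # stands in for the frozen pretrained policy
action = torch.randint(0, N, (B, ))
value_new = torch.randn(B, requires_grad=True)
value_old = torch.randn(B)
adv, return_ = torch.randn(B), torch.randn(B)

data = ppo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, None, logit_pretrained)
loss, info = ppo_error(data, clip_ratio=0.2, kl_type='k3')

# loss now carries a fourth field, kl_div, which the policy weights into the total loss.
total_loss = loss.policy_loss + 0.5 * loss.value_loss - 0.01 * loss.entropy_loss + 0.1 * loss.kl_div
total_loss.backward()
print(loss.kl_div.item(), info.approx_kl, info.clipfrac)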

dizoo/atari/config/serial/pong/pong_ppo_config.py

Lines changed: 7 additions & 0 deletions

@@ -1,6 +1,7 @@
 from easydict import EasyDict

 pong_ppo_config = dict(
+    exp_name='pong_ppo_seed0',
     env=dict(
         collector_env_num=8,
         evaluator_env_num=8,
@@ -39,6 +40,12 @@
             ignore_done=False,
             grad_clip_type='clip_norm',
             grad_clip_value=0.5,
+            # KL divergence regularization between current policy and pretrained policy.
+            # Supported KL divergence estimators: ['k1', 'k2', 'k3'].
+            # KL divergence loss will be calculated only when pretrained_model_path is provided.
+            kl_beta=0.01,
+            kl_type='k1',
+            pretrained_model_path=None,
         ),
         collect=dict(
             n_sample=3200,

dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py

Lines changed: 6 additions & 0 deletions

@@ -44,6 +44,12 @@
             ignore_done=False,
             grad_clip_type='clip_norm',
             grad_clip_value=0.5,
+            # KL divergence regularization between current policy and pretrained policy.
+            # Supported KL divergence estimators: ['k1', 'k2', 'k3'].
+            # KL divergence loss will be calculated only when pretrained_model_path is provided.
+            kl_beta=0.05,
+            kl_type='k1',
+            pretrained_model_path=None,
         ),
         collect=dict(
             n_sample=1024,

dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py

Lines changed: 0 additions & 1 deletion

@@ -63,4 +63,3 @@
     from ding.entry import serial_pipeline
     with DDPContext():
         serial_pipeline((main_config, create_config), seed=0)
-