
Commit 471e351

feature(wyx): add three KL-divergence variants
1 parent f6ee768 commit 471e351

File tree

4 files changed: +78 additions, −17 deletions


ding/policy/ppo.py

Lines changed: 21 additions & 6 deletions
@@ -76,6 +76,10 @@ class PPOPolicy(Policy):
             grad_clip_value=0.5,
             # (bool) Whether ignore done (usually for max step termination env).
             ignore_done=False,
+            # (str) The type of KL divergence loss, ['k1', 'k2', 'k3']
+            kl_type='k1',
+            # (float) The weight of KL divergence loss.
+            kl_beta=0.0,
         ),
         # collect_mode config
         collect=dict(
@@ -192,6 +196,8 @@ def _init_learn(self) -> None:
         self._clip_ratio = self._cfg.learn.clip_ratio
         self._adv_norm = self._cfg.learn.adv_norm
         self._value_norm = self._cfg.learn.value_norm
+        self._kl_type = self._cfg.learn.kl_type
+        self._kl_beta = self._cfg.learn.kl_beta
         if self._value_norm:
             self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
         self._gamma = self._cfg.collect.discount_factor
@@ -291,27 +297,29 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                         output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
                         batch['return'], batch['weight']
                     )
-                    ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio)
+                    ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
                 elif self._action_space == 'discrete':
                     ppo_batch = ppo_data(
                         output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
                         batch['return'], batch['weight']
                     )
-                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
+                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio, kl_type=self._kl_type)
                 elif self._action_space == 'hybrid':
                     # discrete part (discrete policy loss and entropy loss)
                     ppo_discrete_batch = ppo_policy_data(
                         output['logit']['action_type'], batch['logit']['action_type'], batch['action']['action_type'],
                         adv, batch['weight']
                     )
-                    ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio)
+                    ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(
+                        ppo_discrete_batch, self._clip_ratio, kl_type=self._kl_type
+                    )
                     # continuous part (continuous policy loss and entropy loss, value loss)
                     ppo_continuous_batch = ppo_data(
                         output['logit']['action_args'], batch['logit']['action_args'], batch['action']['action_args'],
                         output['value'], batch['value'], adv, batch['return'], batch['weight']
                     )
                     ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
-                        ppo_continuous_batch, self._clip_ratio
+                        ppo_continuous_batch, self._clip_ratio, kl_type=self._kl_type
                     )
                     # sum discrete and continuous loss
                     ppo_loss = type(ppo_continuous_loss)(
@@ -320,10 +328,15 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     )
                     ppo_info = type(ppo_continuous_info)(
                         max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
-                        max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
+                        max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac), ppo_continuous_info.kl_div
                     )
                 wv, we = self._value_weight, self._entropy_weight
-                total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+                kl_div = ppo_info.kl_div
+                # add the KL-divergence penalty, weighted by kl_beta, on top of the standard PPO loss
+                total_loss = (
+                    ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss +
+                    self._kl_beta * kl_div
+                )

                 self._optimizer.zero_grad()
                 total_loss.backward()
@@ -346,6 +359,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     'value_max': output['value'].max().item(),
                     'approx_kl': ppo_info.approx_kl,
                     'clipfrac': ppo_info.clipfrac,
+                    'kl_div': kl_div.item(),
                 }
                 if self._action_space == 'continuous':
                     return_info.update(
@@ -593,6 +607,7 @@ def _monitor_vars_learn(self) -> List[str]:
             'clipfrac',
             'value_max',
             'value_mean',
+            'kl_div',
         ]
         if self._action_space == 'continuous':
             variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act']
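In effect, the learner now simply adds a weighted KL term to the usual PPO objective. The snippet below is a toy illustration of that composition with made-up scalar values; it is not DI-engine code, and the numbers are arbitrary.

# Toy illustration of the new objective; all values are arbitrary.
import torch

policy_loss = torch.tensor(0.12)
value_loss = torch.tensor(0.85)
entropy_loss = torch.tensor(1.40)
kl_div = torch.tensor(0.02)          # whichever of k1/k2/k3 ppo_info carries

wv, we, kl_beta = 0.5, 0.001, 0.01   # value_weight, entropy_weight, and the new kl_beta

# mirrors the updated _forward_learn: the KL penalty is added on top of the
# clipped-surrogate + value + entropy objective
total_loss = policy_loss + wv * value_loss - we * entropy_loss + kl_beta * kl_div
print(round(total_loss.item(), 4))   # ~0.5438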

ding/rl_utils/ppo.py

Lines changed: 53 additions & 11 deletions
@@ -19,7 +19,7 @@
 ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
 ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
 ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
-ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])
+ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac', 'kl_div'])


 def shape_fn_ppo(args, kwargs):
@@ -46,7 +46,8 @@ def ppo_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
         use_value_clip: bool = True,
-        dual_clip: Optional[float] = None
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -57,6 +58,7 @@ def ppo_error(
         - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -97,7 +99,7 @@ def ppo_error(
     )
     logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data
     policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight)
-    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip)
+    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)
     value_data = ppo_value_data(value_new, value_old, return_, weight)
     value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)

@@ -108,7 +110,8 @@ def ppo_policy_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
         dual_clip: Optional[float] = None,
-        entropy_bonus: bool = True
+        entropy_bonus: bool = True,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -119,6 +122,7 @@ def ppo_policy_error(
         - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf), \
             defaults to 5.0, if you don't want to use it, set this parameter to None
         - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
+        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'
     Returns:
         - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -180,7 +184,18 @@ def ppo_policy_error(
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    logr = logp_old - logp_new
+    if kl_type == 'k1':
+        kl_div = logr.mean()
+    elif kl_type == 'k2':
+        kl_div = (logr ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        kl_div = (torch.exp(-logr) - 1 + logr).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)


 def ppo_value_error(
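The three branches above are the Monte-Carlo KL estimators commonly labelled k1/k2/k3 (cf. John Schulman's note "Approximating KL Divergence"): with logr = logp_old - logp_new and data sampled from the old policy, k1 is the plain unbiased estimator (high variance, individual samples can be negative), k2 is a low-variance but biased quadratic approximation, and k3 is unbiased, lower variance, and always non-negative. The following standalone sketch is not part of this commit; the Gaussian "policies" and sample count are made up purely to show the estimators agreeing with the exact KL.

# Illustrative sketch only: two toy Gaussian "policies", not DI-engine code.
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
pi_old = Normal(0.0, 1.0)                        # behaviour (old) policy
pi_new = Normal(0.3, 1.0)                        # current (new) policy
x = pi_old.sample((100000,))                     # PPO data is collected by the old policy
logr = pi_old.log_prob(x) - pi_new.log_prob(x)   # same log-ratio as in ppo_policy_error

k1 = logr.mean()                                 # unbiased, high variance; per-sample values can be negative
k2 = (logr ** 2 / 2).mean()                      # low variance, but biased (quadratic approximation)
k3 = (torch.exp(-logr) - 1 + logr).mean()        # unbiased, low variance, always non-negative
exact = kl_divergence(pi_old, pi_new)            # closed-form KL(pi_old || pi_new) = 0.045 here
print(k1.item(), k2.item(), k3.item(), exact.item())  # all three land near 0.045

Which estimator to prefer depends on how the penalty is used; k3 is often chosen for KL penalties because it cannot go negative, while this commit keeps 'k1' as the default.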
@@ -232,7 +247,8 @@ def ppo_error_continuous(
         data: namedtuple,
         clip_ratio: float = 0.2,
         use_value_clip: bool = True,
-        dual_clip: Optional[float] = None
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
 ) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
@@ -243,6 +259,7 @@ def ppo_error_continuous(
         - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -314,12 +331,25 @@ def ppo_error_continuous(
     else:
         value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

-    return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+    logr = logp_old - logp_new
+    if kl_type == 'k1':
+        kl_div = logr.mean()
+    elif kl_type == 'k2':
+        kl_div = (logr ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        kl_div = (torch.exp(-logr) - 1 + logr).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+    return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)


-def ppo_policy_error_continuous(data: namedtuple,
-                                clip_ratio: float = 0.2,
-                                dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
+def ppo_policy_error_continuous(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        dual_clip: Optional[float] = None,
+        kl_type: str = 'k1'
+) -> Tuple[namedtuple, namedtuple]:
     """
     Overview:
         Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
@@ -328,6 +358,7 @@ def ppo_policy_error_continuous(data: namedtuple,
         - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
         - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
             defaults to 5.0, if you don't want to use it, set this parameter to None
+        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'
     Returns:
         - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -377,4 +408,15 @@ def ppo_policy_error_continuous(data: namedtuple,
     approx_kl = (logp_old - logp_new).mean().item()
     clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
     clipfrac = torch.as_tensor(clipped).float().mean().item()
-    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+    logr = logp_old - logp_new
+    if kl_type == 'k1':
+        kl_div = logr.mean()
+    elif kl_type == 'k2':
+        kl_div = (logr ** 2 / 2).mean()
+    elif kl_type == 'k3':
+        kl_div = (torch.exp(-logr) - 1 + logr).mean()
+    else:
+        raise ValueError(f"Unknown kl_type: {kl_type}")
+
+    return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div)

dizoo/atari/config/serial/pong/pong_ppo_config.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@
             ignore_done=False,
             grad_clip_type='clip_norm',
             grad_clip_value=0.5,
+            kl_beta=0.01,
+            kl_type='k1',
         ),
         collect=dict(
             n_sample=3200,

dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@
             ignore_done=False,
             grad_clip_type='clip_norm',
             grad_clip_value=0.5,
+            kl_beta=0.05,
+            kl_type='k1',
         ),
         collect=dict(
             n_sample=1024,
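Enabling the penalty in any other serial PPO config follows the same pattern. The fragment below is a hypothetical example, not part of the diff: only kl_type and kl_beta come from this commit, the remaining keys are illustrative defaults, and kl_beta generally needs per-environment tuning (the commit itself uses 0.01 for Pong and 0.05 for SpaceInvaders).

# Hypothetical config fragment; every key except kl_type / kl_beta is an illustrative default.
my_ppo_learn_config = dict(
    policy=dict(
        learn=dict(
            epoch_per_collect=10,
            batch_size=320,
            learning_rate=3e-4,
            value_weight=0.5,
            entropy_weight=0.001,
            clip_ratio=0.2,
            grad_clip_type='clip_norm',
            grad_clip_value=0.5,
            # new in this commit: pick the KL estimator and its weight; kl_beta=0.0 disables the penalty
            kl_type='k3',
            kl_beta=0.01,
        ),
    ),
)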
