From 471e35170bb5cdd3e8a3c20ee351269ec48e9e67 Mon Sep 17 00:00:00 2001 From: wyx Date: Wed, 16 Jul 2025 11:26:02 +0000 Subject: [PATCH 1/6] feature(wyx): add three KL-divergence variants --- ding/policy/ppo.py | 27 ++++++-- ding/rl_utils/ppo.py | 64 +++++++++++++++---- .../config/serial/pong/pong_ppo_config.py | 2 + .../spaceinvaders_onppo_config.py | 2 + 4 files changed, 78 insertions(+), 17 deletions(-) diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 958cba1d83..902c4f053f 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -76,6 +76,10 @@ class PPOPolicy(Policy): grad_clip_value=0.5, # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, + # (str) The type of KL divergence loss, ['k1', 'k2', 'k3'] + kl_type='k1', + # (float) The weight of KL divergence loss. + kl_beta=0.0, ), # collect_mode config collect=dict( @@ -192,6 +196,8 @@ def _init_learn(self) -> None: self._clip_ratio = self._cfg.learn.clip_ratio self._adv_norm = self._cfg.learn.adv_norm self._value_norm = self._cfg.learn.value_norm + self._kl_type = self._cfg.learn.kl_type + self._kl_beta = self._cfg.learn.kl_beta if self._value_norm: self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device) self._gamma = self._cfg.collect.discount_factor @@ -291,27 +297,29 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, batch['return'], batch['weight'] ) - ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio) + ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio, kl_type=self._kl_type) elif self._action_space == 'discrete': ppo_batch = ppo_data( output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, batch['return'], batch['weight'] ) - ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio, kl_type=self._kl_type) elif self._action_space == 'hybrid': # discrete part (discrete policy loss and entropy loss) ppo_discrete_batch = ppo_policy_data( output['logit']['action_type'], batch['logit']['action_type'], batch['action']['action_type'], adv, batch['weight'] ) - ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio) + ppo_discrete_loss, ppo_discrete_info = ppo_policy_error( + ppo_discrete_batch, self._clip_ratio, kl_type=self._kl_type + ) # continuous part (continuous policy loss and entropy loss, value loss) ppo_continuous_batch = ppo_data( output['logit']['action_args'], batch['logit']['action_args'], batch['action']['action_args'], output['value'], batch['value'], adv, batch['return'], batch['weight'] ) ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous( - ppo_continuous_batch, self._clip_ratio + ppo_continuous_batch, self._clip_ratio, kl_type=self._kl_type ) # sum discrete and continuous loss ppo_loss = type(ppo_continuous_loss)( @@ -320,10 +328,15 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: ) ppo_info = type(ppo_continuous_info)( max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl), - max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac) + max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac), ppo_continuous_info.kl_div ) wv, we = self._value_weight, self._entropy_weight - total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + kl_div = ppo_info.kl_div + # 正确的、符合规范的修改 + 
total_loss = ( + ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + + self._kl_beta * kl_div + ) self._optimizer.zero_grad() total_loss.backward() @@ -346,6 +359,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 'value_max': output['value'].max().item(), 'approx_kl': ppo_info.approx_kl, 'clipfrac': ppo_info.clipfrac, + 'kl_div': kl_div.item(), } if self._action_space == 'continuous': return_info.update( @@ -593,6 +607,7 @@ def _monitor_vars_learn(self) -> List[str]: 'clipfrac', 'value_max', 'value_mean', + 'kl_div', ] if self._action_space == 'continuous': variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act'] diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index c88c647e7c..57fb9c3c35 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -19,7 +19,7 @@ ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight']) ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss']) ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss']) -ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac']) +ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac', 'kl_div']) def shape_fn_ppo(args, kwargs): @@ -46,7 +46,8 @@ def ppo_error( data: namedtuple, clip_ratio: float = 0.2, use_value_clip: bool = True, - dual_clip: Optional[float] = None + dual_clip: Optional[float] = None, + kl_type: str = 'k1' ) -> Tuple[namedtuple, namedtuple]: """ Overview: @@ -57,6 +58,7 @@ def ppo_error( - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None + - kl_type (:obj:`str`): which kl loss to use, default set to 'approx' Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -97,7 +99,7 @@ def ppo_error( ) logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight) - policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip) + policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type) value_data = ppo_value_data(value_new, value_old, return_, weight) value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip) @@ -108,7 +110,8 @@ def ppo_policy_error( data: namedtuple, clip_ratio: float = 0.2, dual_clip: Optional[float] = None, - entropy_bonus: bool = True + entropy_bonus: bool = True, + kl_type: str = 'k1' ) -> Tuple[namedtuple, namedtuple]: """ Overview: @@ -119,6 +122,7 @@ def ppo_policy_error( - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf), \ defaults to 5.0, if you don't want to use it, set this parameter to None - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it. 
+ - kl_type (:obj:`str`): which kl loss to use, default set to 'k1' Returns: - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -180,7 +184,18 @@ def ppo_policy_error( approx_kl = (logp_old - logp_new).mean().item() clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) clipfrac = torch.as_tensor(clipped).float().mean().item() - return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac) + + logr = logp_old - logp_new + if kl_type == 'k1': + kl_div = logr.mean() + elif kl_type == 'k2': + kl_div = (logr ** 2 / 2).mean() + elif kl_type == 'k3': + kl_div = (torch.exp(-logr) - 1 + logr).mean() + else: + raise ValueError(f"Unknown kl_type: {kl_type}") + + return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) def ppo_value_error( @@ -232,7 +247,8 @@ def ppo_error_continuous( data: namedtuple, clip_ratio: float = 0.2, use_value_clip: bool = True, - dual_clip: Optional[float] = None + dual_clip: Optional[float] = None, + kl_type: str = 'k1' ) -> Tuple[namedtuple, namedtuple]: """ Overview: @@ -243,6 +259,7 @@ def ppo_error_continuous( - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1' Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -314,12 +331,25 @@ def ppo_error_continuous( else: value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean() - return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac) + logr = logp_old - logp_new + if kl_type == 'k1': + kl_div = logr.mean() + elif kl_type == 'k2': + kl_div = (logr ** 2 / 2).mean() + elif kl_type == 'k3': + kl_div = (torch.exp(-logr) - 1 + logr).mean() + else: + raise ValueError(f"Unknown kl_type: {kl_type}") + + return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) -def ppo_policy_error_continuous(data: namedtuple, - clip_ratio: float = 0.2, - dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]: +def ppo_policy_error_continuous( + data: namedtuple, + clip_ratio: float = 0.2, + dual_clip: Optional[float] = None, + kl_type: str = 'k1' +) -> Tuple[namedtuple, namedtuple]: """ Overview: Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip @@ -328,6 +358,7 @@ def ppo_policy_error_continuous(data: namedtuple, - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2 - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 
5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1' Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -377,4 +408,15 @@ def ppo_policy_error_continuous(data: namedtuple, approx_kl = (logp_old - logp_new).mean().item() clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) clipfrac = torch.as_tensor(clipped).float().mean().item() - return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac) + + logr = logp_old - logp_new + if kl_type == 'k1': + kl_div = logr.mean() + elif kl_type == 'k2': + kl_div = (logr ** 2 / 2).mean() + elif kl_type == 'k3': + kl_div = (torch.exp(-logr) - 1 + logr).mean() + else: + raise ValueError(f"Unknown kl_type: {kl_type}") + + return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) diff --git a/dizoo/atari/config/serial/pong/pong_ppo_config.py b/dizoo/atari/config/serial/pong/pong_ppo_config.py index df74a30a55..c501f47709 100644 --- a/dizoo/atari/config/serial/pong/pong_ppo_config.py +++ b/dizoo/atari/config/serial/pong/pong_ppo_config.py @@ -39,6 +39,8 @@ ignore_done=False, grad_clip_type='clip_norm', grad_clip_value=0.5, + kl_beta=0.01, + kl_type='k1', ), collect=dict( n_sample=3200, diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py index cb94b49e3b..77b7624194 100644 --- a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py @@ -44,6 +44,8 @@ ignore_done=False, grad_clip_type='clip_norm', grad_clip_value=0.5, + kl_beta=0.05, + kl_type='k1', ), collect=dict( n_sample=1024, From 37fb707e6301d285c984ae1d202d8206ebf1c912 Mon Sep 17 00:00:00 2001 From: wyx Date: Sun, 20 Jul 2025 08:38:45 +0000 Subject: [PATCH 2/6] fix bugs and add description for KL-divergence variants --- ding/policy/ppo.py | 25 +++++- ding/rl_utils/ppo.py | 85 ++++++++++++------- .../config/serial/pong/pong_ppo_config.py | 2 + .../spaceinvaders_onppo_config.py | 1 + 4 files changed, 79 insertions(+), 34 deletions(-) diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 902c4f053f..5cd731c4da 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -77,9 +77,14 @@ class PPOPolicy(Policy): # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, # (str) The type of KL divergence loss, ['k1', 'k2', 'k3'] + # http://joschu.net/blog/kl-approx.html kl_type='k1', # (float) The weight of KL divergence loss. kl_beta=0.0, + # (str or None) The path of pretrained model checkpoint. + # If provided, KL regularizer will be calculated between current policy and pretrained policy. + # Default to None, which means KL is not calculated. 
+ pretrained_model_path=None, ), # collect_mode config collect=dict( @@ -190,6 +195,15 @@ def _init_learn(self) -> None: self._learn_model = model_wrap(self._model, wrapper_name='base') + # load pretrained model + if self._cfg.learn.pretrained_model_path is not None: + self._pretrained_model = copy.deepcopy(self._model) + state_dict = torch.load(self._cfg.learn.pretrained_model_path, map_location='cpu') + self._pretrained_model.load_state_dict(state_dict) + self._pretrained_model.eval() + else: + self._pretrained_model = None + # Algorithm config self._value_weight = self._cfg.learn.value_weight self._entropy_weight = self._cfg.learn.entropy_weight @@ -291,17 +305,23 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Normalize advantage in a train_batch adv = (adv - adv.mean()) / (adv.std() + 1e-8) + if self._pretrained_model is not None: + with torch.no_grad(): + logit_pretrained = self._pretrained_model.forward(batch['obs'], mode='compute_actor')['logit'] + else: + logit_pretrained = None + # Calculate ppo error if self._action_space == 'continuous': ppo_batch = ppo_data( output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, - batch['return'], batch['weight'] + batch['return'], batch['weight'], logit_pretrained ) ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio, kl_type=self._kl_type) elif self._action_space == 'discrete': ppo_batch = ppo_data( output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv, - batch['return'], batch['weight'] + batch['return'], batch['weight'], logit_pretrained ) ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio, kl_type=self._kl_type) elif self._action_space == 'hybrid': @@ -332,7 +352,6 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: ) wv, we = self._value_weight, self._entropy_weight kl_div = ppo_info.kl_div - # 正确的、符合规范的修改 total_loss = ( ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + self._kl_beta * kl_div diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index 57fb9c3c35..df8d69f048 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -6,13 +6,16 @@ from ding.hpc_rl import hpc_wrapper ppo_data = namedtuple( - 'ppo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight'] + 'ppo_data', + ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained'] ) ppo_data_continuous = namedtuple( 'ppo_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight'] ) -ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight']) +ppo_policy_data = namedtuple( + 'ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained'] +) ppo_policy_data_continuous = namedtuple( 'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight'] ) @@ -22,6 +25,32 @@ ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac', 'kl_div']) +def calculate_kl_div(logr: torch.Tensor, kl_type: str) -> torch.Tensor: + """ + Overview: + Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)], + where q is the current policy and p is the pretrained policy. + The implementation is based on John Schulman's blog post "Approximating KL Divergence". 
+ Reference: http://joschu.net/blog/kl-approx.html + Arguments: + - logr (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be log(q/p) = logp_new - logp_pretrained. + - kl_type (:obj:`str`): The type of KL divergence estimator to use. + - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`. + - 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`. + - 'k3': An unbiased, low-variance estimator: `E_q[(p/q - 1) - log(p/q)]`. + Returns: + - kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate. + """ + if kl_type == 'k1': + return logr.mean() + elif kl_type == 'k2': + return (logr ** 2 / 2).mean() + elif kl_type == 'k3': + return (torch.exp(-logr) - 1 + logr).mean() + else: + raise ValueError(f"Unknown kl_type: {kl_type}") + + def shape_fn_ppo(args, kwargs): r""" Overview: @@ -97,8 +126,8 @@ def ppo_error( assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( dual_clip ) - logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data - policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight) + logit_new, logit_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data + policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight, logit_pretrained) policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type) value_data = ppo_value_data(value_new, value_old, return_, weight) value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip) @@ -152,7 +181,7 @@ def ppo_policy_error( .. note:: For the action mask often used in LLM/VLM, users can set the `weight` to the action mask. """ - logit_new, logit_old, action, adv, weight = data + logit_new, logit_old, action, adv, weight, logit_pretrained = data if weight is None: weight = torch.ones_like(adv) dist_new = torch.distributions.categorical.Categorical(logits=logit_new) @@ -185,15 +214,13 @@ def ppo_policy_error( clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) clipfrac = torch.as_tensor(clipped).float().mean().item() - logr = logp_old - logp_new - if kl_type == 'k1': - kl_div = logr.mean() - elif kl_type == 'k2': - kl_div = (logr ** 2 / 2).mean() - elif kl_type == 'k3': - kl_div = (torch.exp(-logr) - 1 + logr).mean() + if logit_pretrained is not None: + dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained) + logp_pretrained = dist_pretrained.log_prob(action) + logr = logp_new - logp_pretrained + kl_div = calculate_kl_div(logr, kl_type) else: - raise ValueError(f"Unknown kl_type: {kl_type}") + kl_div = 0 return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) @@ -298,7 +325,7 @@ def ppo_error_continuous( assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( dual_clip ) - mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight = data + mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data if weight is None: weight = torch.ones_like(adv) @@ -331,15 +358,13 @@ def ppo_error_continuous( else: value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean() - logr = logp_old - logp_new - if kl_type == 'k1': - kl_div = logr.mean() - elif kl_type == 'k2': - kl_div = (logr ** 2 / 2).mean() - elif kl_type == 'k3': - kl_div = (torch.exp(-logr) - 1 + logr).mean() + if 
logit_pretrained is not None: + dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1) + logp_pretrained = dist_pretrained.log_prob(action) + logr = logp_new - logp_pretrained + kl_div = calculate_kl_div(logr, kl_type) else: - raise ValueError(f"Unknown kl_type: {kl_type}") + kl_div = 0 return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) @@ -384,7 +409,7 @@ def ppo_policy_error_continuous( assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( dual_clip ) - mu_sigma_new, mu_sigma_old, action, adv, weight = data + mu_sigma_new, mu_sigma_old, action, adv, weight, logit_pretrained = data if weight is None: weight = torch.ones_like(adv) @@ -409,14 +434,12 @@ def ppo_policy_error_continuous( clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) clipfrac = torch.as_tensor(clipped).float().mean().item() - logr = logp_old - logp_new - if kl_type == 'k1': - kl_div = logr.mean() - elif kl_type == 'k2': - kl_div = (logr ** 2 / 2).mean() - elif kl_type == 'k3': - kl_div = (torch.exp(-logr) - 1 + logr).mean() + if logit_pretrained is not None: + dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1) + logp_pretrained = dist_pretrained.log_prob(action) + logr = logp_new - logp_pretrained + kl_div = calculate_kl_div(logr, kl_type) else: - raise ValueError(f"Unknown kl_type: {kl_type}") + kl_div = 0 return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) diff --git a/dizoo/atari/config/serial/pong/pong_ppo_config.py b/dizoo/atari/config/serial/pong/pong_ppo_config.py index c501f47709..8541e29864 100644 --- a/dizoo/atari/config/serial/pong/pong_ppo_config.py +++ b/dizoo/atari/config/serial/pong/pong_ppo_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict pong_ppo_config = dict( + exp_name='pong_ppo_seed0', env=dict( collector_env_num=8, evaluator_env_num=8, @@ -41,6 +42,7 @@ grad_clip_value=0.5, kl_beta=0.01, kl_type='k1', + pretrained_model_path='The path of your pretrained model', ), collect=dict( n_sample=3200, diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py index 77b7624194..840e4c0ae7 100644 --- a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py @@ -46,6 +46,7 @@ grad_clip_value=0.5, kl_beta=0.05, kl_type='k1', + pretrained_model_path='The path of your pretrained model', ), collect=dict( n_sample=1024, From fc668c8fb5ffb2dadbcea5c1f8908f9ded07ed39 Mon Sep 17 00:00:00 2001 From: wyx Date: Sun, 20 Jul 2025 09:02:35 +0000 Subject: [PATCH 3/6] add a period for each comment line --- ding/policy/ppo.py | 4 ++-- ding/rl_utils/ppo.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 5cd731c4da..d1a94eaa32 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -76,8 +76,8 @@ class PPOPolicy(Policy): grad_clip_value=0.5, # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, - # (str) The type of KL divergence loss, ['k1', 'k2', 'k3'] - # http://joschu.net/blog/kl-approx.html + # (str) The type of KL divergence loss, ['k1', 'k2', 'k3']. + # Reference: http://joschu.net/blog/kl-approx.html kl_type='k1', # (float) The weight of KL divergence loss. 
kl_beta=0.0, diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index df8d69f048..109d79e177 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -87,7 +87,7 @@ def ppo_error( - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None - - kl_type (:obj:`str`): which kl loss to use, default set to 'approx' + - kl_type (:obj:`str`): which kl loss to use, default set to 'approx'. Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -151,7 +151,7 @@ def ppo_policy_error( - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf), \ defaults to 5.0, if you don't want to use it, set this parameter to None - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it. - - kl_type (:obj:`str`): which kl loss to use, default set to 'k1' + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. Returns: - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -286,7 +286,7 @@ def ppo_error_continuous( - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None - - kl_type (:obj:`str`): which kl loss to use, default set to 'k1' + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -383,7 +383,7 @@ def ppo_policy_error_continuous( - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2 - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None - - kl_type (:obj:`str`): which kl loss to use, default set to 'k1' + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. 
Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar From 64872286102c089e8bd71033fc457cac363523a4 Mon Sep 17 00:00:00 2001 From: wyx Date: Tue, 29 Jul 2025 03:10:16 +0000 Subject: [PATCH 4/6] fix bugs and add KL-divergence parameter descriptions --- ding/policy/ppo.py | 13 ++--- ding/rl_utils/ppo.py | 48 ++++++++++--------- .../config/serial/pong/pong_ppo_config.py | 5 +- .../spaceinvaders_onppo_config.py | 5 +- .../config/cartpole_dqn_ddp_config.py | 1 - .../halfcheetah_medium_expert_iql_config.py | 1 - .../config/halfcheetah_medium_iql_config.py | 1 - .../halfcheetah_medium_replay_iql_config.py | 1 - .../config/hopper_medium_expert_iql_config.py | 1 - dizoo/d4rl/config/hopper_medium_iql_config.py | 1 - .../config/hopper_medium_replay_iql_config.py | 1 - 11 files changed, 41 insertions(+), 37 deletions(-) diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index d1a94eaa32..3d7ee4b29c 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -76,12 +76,12 @@ class PPOPolicy(Policy): grad_clip_value=0.5, # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, - # (str) The type of KL divergence loss, ['k1', 'k2', 'k3']. + # (str) The type of KL divergence loss between current policy and pretrained policy, ['k1', 'k2', 'k3']. # Reference: http://joschu.net/blog/kl-approx.html kl_type='k1', # (float) The weight of KL divergence loss. kl_beta=0.0, - # (str or None) The path of pretrained model checkpoint. + # (Optional[str]) The path of pretrained model checkpoint. # If provided, KL regularizer will be calculated between current policy and pretrained policy. # Default to None, which means KL is not calculated. 
pretrained_model_path=None, @@ -344,14 +344,14 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # sum discrete and continuous loss ppo_loss = type(ppo_continuous_loss)( ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss, - ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss + ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss, ppo_continuous_loss.kl_div ) ppo_info = type(ppo_continuous_info)( max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl), - max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac), ppo_continuous_info.kl_div + max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac) ) wv, we = self._value_weight, self._entropy_weight - kl_div = ppo_info.kl_div + kl_div = ppo_loss.kl_div total_loss = ( ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + self._kl_beta * kl_div @@ -626,8 +626,9 @@ def _monitor_vars_learn(self) -> List[str]: 'clipfrac', 'value_max', 'value_mean', - 'kl_div', ] + if self._pretrained_model is not None: + variables += ['kl_div'] if self._action_space == 'continuous': variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act'] return variables diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index 109d79e177..da8004e94e 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -10,22 +10,24 @@ ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained'] ) ppo_data_continuous = namedtuple( - 'ppo_data_continuous', - ['mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight'] + 'ppo_data_continuous', [ + 'mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', + 'logit_pretrained' + ] ) ppo_policy_data = namedtuple( 'ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained'] ) ppo_policy_data_continuous = namedtuple( - 'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight'] + 'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight', 'logit_pretrained'] ) ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight']) -ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss']) -ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss']) -ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac', 'kl_div']) +ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss', 'kl_div']) +ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss', 'kl_div']) +ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac']) -def calculate_kl_div(logr: torch.Tensor, kl_type: str) -> torch.Tensor: +def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor: """ Overview: Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)], @@ -33,7 +35,7 @@ def calculate_kl_div(logr: torch.Tensor, kl_type: str) -> torch.Tensor: The implementation is based on John Schulman's blog post "Approximating KL Divergence". Reference: http://joschu.net/blog/kl-approx.html Arguments: - - logr (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be log(q/p) = logp_new - logp_pretrained. + - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be log(q/p) = logp_new - logp_pretrained. 
- kl_type (:obj:`str`): The type of KL divergence estimator to use. - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`. - 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`. @@ -42,11 +44,11 @@ def calculate_kl_div(logr: torch.Tensor, kl_type: str) -> torch.Tensor: - kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate. """ if kl_type == 'k1': - return logr.mean() + return log_ratio.mean() elif kl_type == 'k2': - return (logr ** 2 / 2).mean() + return (log_ratio ** 2 / 2).mean() elif kl_type == 'k3': - return (torch.exp(-logr) - 1 + logr).mean() + return (torch.exp(-log_ratio) - 1 + log_ratio).mean() else: raise ValueError(f"Unknown kl_type: {kl_type}") @@ -87,7 +89,7 @@ def ppo_error( - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\ defaults to 5.0, if you don't want to use it, set this parameter to None - - kl_type (:obj:`str`): which kl loss to use, default set to 'approx'. + - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'. Returns: - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar @@ -132,7 +134,9 @@ def ppo_error( value_data = ppo_value_data(value_new, value_old, return_, weight) value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip) - return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info + return ppo_loss( + policy_output.policy_loss, value_loss, policy_output.entropy_loss, policy_output.kl_div + ), policy_info def ppo_policy_error( @@ -217,12 +221,12 @@ def ppo_policy_error( if logit_pretrained is not None: dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained) logp_pretrained = dist_pretrained.log_prob(action) - logr = logp_new - logp_pretrained - kl_div = calculate_kl_div(logr, kl_type) + log_ratio = logp_new - logp_pretrained + kl_div = calculate_kl_div(log_ratio, kl_type) else: kl_div = 0 - return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) + return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac) def ppo_value_error( @@ -361,12 +365,12 @@ def ppo_error_continuous( if logit_pretrained is not None: dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1) logp_pretrained = dist_pretrained.log_prob(action) - logr = logp_new - logp_pretrained - kl_div = calculate_kl_div(logr, kl_type) + log_ratio = logp_new - logp_pretrained + kl_div = calculate_kl_div(log_ratio, kl_type) else: kl_div = 0 - return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac, kl_div) + return ppo_loss(policy_loss, value_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac) def ppo_policy_error_continuous( @@ -437,9 +441,9 @@ def ppo_policy_error_continuous( if logit_pretrained is not None: dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1) logp_pretrained = dist_pretrained.log_prob(action) - logr = logp_new - logp_pretrained - kl_div = calculate_kl_div(logr, kl_type) + log_ratio = logp_new - logp_pretrained + kl_div = calculate_kl_div(log_ratio, kl_type) else: kl_div = 0 - return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, 
clipfrac, kl_div) + return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac) diff --git a/dizoo/atari/config/serial/pong/pong_ppo_config.py b/dizoo/atari/config/serial/pong/pong_ppo_config.py index 8541e29864..0d7ae3ed78 100644 --- a/dizoo/atari/config/serial/pong/pong_ppo_config.py +++ b/dizoo/atari/config/serial/pong/pong_ppo_config.py @@ -40,9 +40,12 @@ ignore_done=False, grad_clip_type='clip_norm', grad_clip_value=0.5, + # KL divergence regularization between current policy and pretrained policy. + # Supported KL divergence estimators: ['k1', 'k2', 'k3']. + # KL divergence loss will be calculated only when pretrained_model_path is provided. kl_beta=0.01, kl_type='k1', - pretrained_model_path='The path of your pretrained model', + pretrained_model_path=None, ), collect=dict( n_sample=3200, diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py index 840e4c0ae7..fb5969282c 100644 --- a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py @@ -44,9 +44,12 @@ ignore_done=False, grad_clip_type='clip_norm', grad_clip_value=0.5, + # KL divergence regularization between current policy and pretrained policy. + # Supported KL divergence estimators: ['k1', 'k2', 'k3']. + # KL divergence loss will be calculated only when pretrained_model_path is provided. kl_beta=0.05, kl_type='k1', - pretrained_model_path='The path of your pretrained model', + pretrained_model_path=None, ), collect=dict( n_sample=1024, diff --git a/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py b/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py index 82d6c673ec..a80662941a 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py @@ -63,4 +63,3 @@ from ding.entry import serial_pipeline with DDPContext(): serial_pipeline((main_config, create_config), seed=0) - diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py index 144feac1dd..e3aa855afe 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py index 545ecf970b..440525a320 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py index d48a1fb472..0974735b72 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_expert_iql_config.py b/dizoo/d4rl/config/hopper_medium_expert_iql_config.py index 6aef029c5e..2eebce2771 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, 
action_shape=3, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_iql_config.py b/dizoo/d4rl/config/hopper_medium_iql_config.py index 8f429be268..61dbb5fac3 100644 --- a/dizoo/d4rl/config/hopper_medium_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_replay_iql_config.py b/dizoo/d4rl/config/hopper_medium_replay_iql_config.py index ad1b222843..df96a84aea 100644 --- a/dizoo/d4rl/config/hopper_medium_replay_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_replay_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None, From 4a919f8c5fb5557530731af5b0b4a5cd9bdcce1b Mon Sep 17 00:00:00 2001 From: wyx Date: Tue, 29 Jul 2025 06:18:08 +0000 Subject: [PATCH 5/6] fix flake8 E501 error in ppo.py docstring --- ding/rl_utils/ppo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index da8004e94e..3709132edc 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -35,7 +35,8 @@ def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor: The implementation is based on John Schulman's blog post "Approximating KL Divergence". Reference: http://joschu.net/blog/kl-approx.html Arguments: - - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be log(q/p) = logp_new - logp_pretrained. + - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be + log(q/p) = logp_new - logp_pretrained. - kl_type (:obj:`str`): The type of KL divergence estimator to use. - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`. - 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`. From e835c7947421f5fbe18a9ee959cbe38a77f13aff Mon Sep 17 00:00:00 2001 From: wyx Date: Tue, 29 Jul 2025 06:31:48 +0000 Subject: [PATCH 6/6] fix trailing whitespace in ppo.py --- ding/rl_utils/ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index 3709132edc..bb83d3b9bc 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -35,7 +35,7 @@ def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor: The implementation is based on John Schulman's blog post "Approximating KL Divergence". Reference: http://joschu.net/blog/kl-approx.html Arguments: - - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be + - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be log(q/p) = logp_new - logp_pretrained. - kl_type (:obj:`str`): The type of KL divergence estimator to use. - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`.
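
The three estimators added in calculate_kl_div follow John Schulman's "Approximating KL Divergence" note (http://joschu.net/blog/kl-approx.html): with log_ratio = log(q/p) sampled under the current policy q, k1 is unbiased but high-variance, k2 is a biased low-variance second-order approximation, and k3 is unbiased with low variance. A minimal standalone sanity check of these formulas, assuming only PyTorch (it does not import DI-engine and is not part of the patch itself), is:

    import torch
    from torch.distributions import Categorical, kl_divergence

    torch.manual_seed(0)

    # q plays the role of the current policy, p the pretrained (reference) policy.
    q = Categorical(logits=torch.randn(6))
    p = Categorical(logits=torch.randn(6))

    # Monte-Carlo samples drawn from q, as in the on-policy PPO update.
    actions = q.sample((100000,))
    log_ratio = q.log_prob(actions) - p.log_prob(actions)  # log(q/p) = logp_new - logp_pretrained

    estimators = {
        'k1': log_ratio.mean(),                                # unbiased, high variance
        'k2': (log_ratio ** 2 / 2).mean(),                     # biased, low variance
        'k3': (torch.exp(-log_ratio) - 1 + log_ratio).mean(),  # unbiased, low variance
    }
    exact = kl_divergence(q, p)
    for name, value in estimators.items():
        print(f"{name}: {value.item():.4f}   exact KL(q||p): {exact.item():.4f}")

With the defaults shipped in this series the regularizer stays inactive: pretrained_model_path=None in both Atari configs, so kl_div is 0 and total_loss reduces to the original PPO loss. To enable it, set policy.learn.kl_beta to a positive value, pick kl_type from ['k1', 'k2', 'k3'], and point policy.learn.pretrained_model_path at a checkpoint compatible with the current model; the extra kl_beta * kl_div term is then added to total_loss and 'kl_div' appears in the monitored variables.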