Skip to content

Commit 7b2701c

Browse files
authored
Tiny fix k1 estimator typo and low_var_kl comments (#483)
1 parent 46e2cd4 commit 7b2701c

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

slime/utils/arguments.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,9 @@ def add_algo_arguments(parser):
615615
parser.add_argument(
616616
"--kl-loss-type",
617617
type=str,
618-
choices=["kl", "k2", "k3", "low_var_kl"],
619-
default="kl",
620-
help="Choose KL loss type: kl, k2, k3 low_var_kl",
618+
choices=["k1", "k2", "k3", "low_var_kl"],
619+
default="k1",
620+
help="Choose KL loss type: kl, k2, k3, low_var_kl",
621621
)
622622
parser.add_argument(
623623
"--advantage-estimator",

slime/utils/ppo_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,20 @@ def compute_approx_kl(
2424

2525
log_ratio = log_probs.float() - log_probs_base.float()
2626

27-
if kl_loss_type == "kl":
27+
if kl_loss_type == "k1":
2828
return log_ratio
2929
elif kl_loss_type == "k2":
3030
log_ratio = log_probs.float() - log_probs_base.float()
3131
log_ratio = log_ratio**2 / 2.0
3232
return log_ratio
3333
elif kl_loss_type == "k3":
34+
# The non-negative KL approximation described in
35+
# http://joschu.net/blog/kl-approx.html
36+
# Besides being non-negative, it is also unbiased and has lower variance.
3437
log_ratio = -log_ratio
3538
log_ratio = log_ratio.exp() - 1 - log_ratio
3639
return log_ratio
3740
elif kl_loss_type == "low_var_kl":
38-
# The non negative kl approximation in
39-
# http://joschu.net/blog/kl-approx.html
40-
# Besides non negative, it is also unbiased and have lower variance.
4141
log_ratio = -log_ratio
4242
log_ratio = log_ratio.exp() - 1 - log_ratio
4343
return torch.clamp(log_ratio, min=-10, max=10)

0 commit comments

Comments
 (0)