Skip to content

Commit ccf59ec

Browse files
samsjaclaude
andcommitted
Rename to squared_kl
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 63403ac commit ccf59ec

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/prime_rl/trainer/rl/loss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def compute_loss(
112112
total_geo_masked_high = []
113113
total_geo_seq_ratio = []
114114
total_teacher_kl = []
115-
total_kl_loss = []
115+
total_squared_kl = []
116116

117117
if teacher_logprobs is None:
118118
teacher_logprobs = [None] * len(trainer_logprobs)
@@ -151,8 +151,8 @@ def compute_loss(
151151
if teacher_logprobs is not None:
152152
advantages = advantages + loss_config.teacher_tau * teacher_kl.detach()
153153

154-
# KL loss: (log π_θ/π_old)² per token
155-
kl_loss_metric = _safe_mean(log_importance_ratio**2, loss_mask)
154+
# Squared KL: (log π_θ/π_old)² per token
155+
squared_kl = _safe_mean(log_importance_ratio**2, loss_mask)
156156

157157
if loss_config.kl_loss_type == "k2":
158158
# Kimi K2 style: direct squared loss (advantages - τ·log_ratio)²
@@ -182,7 +182,7 @@ def compute_loss(
182182
total_geo_masked_low.append(geo_mask_low.float())
183183
total_geo_masked_high.append(geo_mask_high.float())
184184
total_geo_seq_ratio.append(geo_seq_ratio)
185-
total_kl_loss.append(kl_loss_metric)
185+
total_squared_kl.append(squared_kl)
186186
if teacher_logprobs is not None:
187187
total_teacher_kl.append(_safe_mean(teacher_kl, loss_mask))
188188

@@ -201,7 +201,7 @@ def compute_loss(
201201
"geo_masked_low": torch.stack(total_geo_masked_low),
202202
"geo_masked_high": torch.stack(total_geo_masked_high),
203203
"geo_seq_ratio": torch.stack(total_geo_seq_ratio),
204-
"kl_loss": torch.stack(total_kl_loss),
204+
"squared_kl": torch.stack(total_squared_kl),
205205
}
206206
if total_teacher_kl:
207207
result["teacher_kl"] = torch.stack(total_teacher_kl)

0 commit comments

Comments
 (0)