@@ -112,7 +112,7 @@ def compute_loss(
112112 total_geo_masked_high = []
113113 total_geo_seq_ratio = []
114114 total_teacher_kl = []
115- total_squared_kl = []
115+ total_kl_loss = []
116116
117117 if teacher_logprobs is None :
118118 teacher_logprobs = [None ] * len (trainer_logprobs )
@@ -151,8 +151,8 @@ def compute_loss(
151151 if teacher_logprobs is not None :
152152 advantages = advantages + loss_config .teacher_tau * teacher_kl .detach ()
153153
154- # Squared KL : (log π_θ/π_old)² per token
155- squared_kl = _safe_mean (log_importance_ratio ** 2 , loss_mask )
154+ # KL loss : (log π_θ/π_old)² per token
155+ kl_loss_metric = _safe_mean (log_importance_ratio ** 2 , loss_mask )
156156
157157 if loss_config .kl_loss_type == "k2" :
158158 # Kimi K2 style: direct squared loss (advantages - τ·log_ratio)²
@@ -182,7 +182,7 @@ def compute_loss(
182182 total_geo_masked_low .append (geo_mask_low .float ())
183183 total_geo_masked_high .append (geo_mask_high .float ())
184184 total_geo_seq_ratio .append (geo_seq_ratio )
185- total_squared_kl .append (squared_kl )
185+ total_kl_loss .append (kl_loss_metric )
186186 if teacher_logprobs is not None :
187187 total_teacher_kl .append (_safe_mean (teacher_kl , loss_mask ))
188188
@@ -201,7 +201,7 @@ def compute_loss(
201201 "geo_masked_low" : torch .stack (total_geo_masked_low ),
202202 "geo_masked_high" : torch .stack (total_geo_masked_high ),
203203 "geo_seq_ratio" : torch .stack (total_geo_seq_ratio ),
204- "squared_kl " : torch .stack (total_squared_kl ),
204+ "kl_loss " : torch .stack (total_kl_loss ),
205205 }
206206 if total_teacher_kl :
207207 result ["teacher_kl" ] = torch .stack (total_teacher_kl )
0 commit comments