@@ -133,7 +133,6 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
-
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
@@ -243,13 +242,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 )
                 self.wandb_run.log(
                     {
+                        "metrics/reward": self.accum_reward.item() / self.accum_count,
+                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/response_length": self.accum_response_length.item() / self.accum_count,
+                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                 )
            self.accum_loss.zero_()
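
Aside: the first hunk's ctx = nullcontext() if need_update else self.booster.no_sync(...) line is the standard gradient-accumulation idiom: skip the cross-rank gradient all-reduce on intermediate microbatches and only synchronize on the microbatch that triggers the optimizer step. Below is a minimal sketch of that pattern, not the ColossalAI Booster API itself, assuming a torch.nn.parallel.DistributedDataParallel-wrapped model whose no_sync() stands in for booster.no_sync(policy_model, optimizer); train_with_accumulation and microbatches are illustrative names, not part of this diff.

from contextlib import nullcontext

def train_with_accumulation(model, optimizer, microbatches, num_microbatches):
    # model: a torch.nn.parallel.DistributedDataParallel module; its no_sync()
    # stands in here for booster.no_sync(self.policy_model, self.optimizer).
    for step_idx, batch in enumerate(microbatches):
        # Step the optimizer only once every num_microbatches microbatches.
        need_update = (step_idx + 1) % num_microbatches == 0
        # Skip the gradient all-reduce on intermediate microbatches; the final
        # backward (under nullcontext) synchronizes the accumulated gradients.
        ctx = nullcontext() if need_update else model.no_sync()
        with ctx:
            loss = model(**batch).loss / num_microbatches
            loss.backward()
        if need_update:
            optimizer.step()
            optimizer.zero_grad()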