Commit b470cbf
[training_utils] fix: RM extra scaling in KL/PG losses (verl-project#4711)
### What does this PR do?
The logged KL/PG losses are currently scaled by the number of micro-batches twice. As a result, the logged metrics represent the mean across micro-batches **divided again by the number of micro-batches**. This PR applies the scaling only once, so the logged metrics represent the plain mean across micro-batches with no extra scaling.
First scaling:
https://github.com/volcengine/verl/blob/cd4072daad2652794ecff0b5816a05afedff8608/verl/workers/actor/dp_actor.py#L533
Second scaling:
https://github.com/volcengine/verl/blob/cd4072daad2652794ecff0b5816a05afedff8608/verl/utils/metric/utils.py#L53
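As a minimal sketch of the effect (hypothetical names and dummy values, not verl's actual code): pre-dividing each appended metric by the number of micro-batches and then averaging over micro-batches again in the reducer yields `mean / n_micro` instead of the mean.

```python
# Sketch of the double scaling; names and values are illustrative, not verl's code.
from statistics import mean

micro_batch_kl_losses = [0.4, 0.6, 0.5, 0.5]   # dummy per-micro-batch KL losses, mean 0.5
n_micro = len(micro_batch_kl_losses)

# First scaling (in the actor's update loop): each metric is appended
# pre-divided by the number of micro-batches.
appended = [loss / n_micro for loss in micro_batch_kl_losses]

# Second scaling (in the metric utility): the reducer averages over micro-batches.
logged_before_fix = mean(appended)              # 0.125 == 0.5 / n_micro
logged_after_fix = mean(micro_batch_kl_losses)  # 0.5, the intended mean
```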
### Test
On `main`, decreasing micro-batch size from 8->2 decreases logged loss
by a factor of 4:
<img width="970" height="640" alt="image"
src="https://github.com/user-attachments/assets/9d6cf0a5-1cef-46ad-9d4b-c1d1d56a9af7"
/>
Decreasing the micro-batch size on this branch does not affect the metric
magnitude:
<img width="988" height="644" alt="image"
src="https://github.com/user-attachments/assets/c8f6bc34-da02-4469-8e16-58b53c6235a9"
/>
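The factor of 4 follows from the double scaling. Going from a micro-batch size of 8 to 2 with the mini-batch size fixed multiplies the number of micro-batches per mini-batch, $n_\text{micro}$, by 4, and the double-scaled metric is

$$
\text{logged} \;=\; \frac{1}{n_\text{micro}} \sum_i \frac{\ell_i}{n_\text{micro}} \;=\; \frac{\bar{\ell}}{n_\text{micro}},
$$

so it shrinks by a factor of 4, while the single-scaled metric stays at the mean $\bar{\ell}$. The run command (shown here for the micro-batch-size-2 configuration):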
```bash
python -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.dataloader_num_workers=0 \
data.return_full_prompt=True \
data.train_files=$SAVE_PATH/gsm8k/train.parquet \
data.val_files=$SAVE_PATH/gsm8k/test.parquet \
data.train_batch_size=8 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
+actor_rollout_ref.ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=10 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_fix_metrics' \
trainer.experiment_name='NEW/ppo_micro_batch_size_per_gpu2' \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.resume_mode="disable" \
trainer.total_epochs=15 \
actor_rollout_ref.actor.use_torch_compile=False \
actor_rollout_ref.actor.fsdp_config.use_torch_compile=False \
trainer.val_before_train=False \
actor_rollout_ref.rollout.enforce_eager=True \
actor_rollout_ref.ref.fsdp_config.use_torch_compile=False
```
### Design & Code Changes
Remove the extra scaling in `dp_actor` so the logged KL/PG losses are averaged over micro-batches only once.

The diff touches `verl/workers/actor/dp_actor.py` and `verl/utils/metric/utils.py` (2 files changed, +10 −5 lines).
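In spirit, the change drops the pre-division when the per-micro-batch losses are appended for logging, leaving the single mean over micro-batches to the metric reduction step. A hedged sketch with hypothetical names, not the actual diff:

```python
# Hypothetical sketch of the shape of the fix; names are assumptions, not the actual diff.
def append_loss_metrics(metrics: dict[str, list[float]],
                        kl_loss: float, pg_loss: float) -> None:
    # Previously the values were appended pre-divided by the number of
    # micro-batches and then averaged over micro-batches again by the metric
    # reducer. Appending the raw per-micro-batch values leaves the reducer's
    # mean as the only scaling.
    metrics.setdefault("actor/kl_loss", []).append(kl_loss)
    metrics.setdefault("actor/pg_loss", []).append(pg_loss)
```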