
Commit b470cbf

JacobHelwig authored and jsfanfanfan committed
[training_utils] fix: RM extra scaling in KL/PG losses (verl-project#4711)
### What does this PR do?

The KL/PG losses currently logged are scaled by the number of micro-batches twice. The result is that the logged metrics represent the mean value across micro-batches **scaled by the number of micro-batches**. This PR scales only once, so that the logged metrics represent the mean across micro-batches with no extra scaling.

First scaling: https://github.com/volcengine/verl/blob/cd4072daad2652794ecff0b5816a05afedff8608/verl/workers/actor/dp_actor.py#L533

Second scaling: https://github.com/volcengine/verl/blob/cd4072daad2652794ecff0b5816a05afedff8608/verl/utils/metric/utils.py#L53

### Test

On `main`, decreasing the micro-batch size from 8 to 2 decreases the logged loss by a factor of 4:

<img width="970" height="640" alt="image" src="https://github.com/user-attachments/assets/9d6cf0a5-1cef-46ad-9d4b-c1d1d56a9af7" />

Decreasing the micro-batch size on this branch does not affect the metric magnitude:

<img width="988" height="644" alt="image" src="https://github.com/user-attachments/assets/c8f6bc34-da02-4469-8e16-58b53c6235a9" />

```bash
python -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.dataloader_num_workers=0 \
    data.return_full_prompt=True \
    data.train_files=$SAVE_PATH/gsm8k/train.parquet \
    data.val_files=$SAVE_PATH/gsm8k/test.parquet \
    data.train_batch_size=8 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
    +actor_rollout_ref.ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=8 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=10 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_fix_metrics' \
    trainer.experiment_name='NEW/ppo_micro_batch_size_per_gpu2' \
    trainer.n_gpus_per_node=1 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.resume_mode="disable" \
    trainer.total_epochs=15 \
    actor_rollout_ref.actor.use_torch_compile=False \
    actor_rollout_ref.actor.fsdp_config.use_torch_compile=False \
    trainer.val_before_train=False \
    actor_rollout_ref.rollout.enforce_eager=True \
    actor_rollout_ref.ref.fsdp_config.use_torch_compile=False
```

### Design & Code Changes

Remove the extra scaling of the logged KL/PG losses in `dp_actor` (the same change is applied to the value-function loss in `dp_critic`).
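To make the double scaling concrete, here is a minimal arithmetic sketch (an illustration, not verl code). It assumes `loss_scale_factor` is roughly `1 / num_micro_batches` (the gradient-accumulation factor) and that the metric pipeline averages the logged per-micro-batch values once more downstream; the loss values are made up.

```python
# Minimal arithmetic sketch of the double scaling (illustrative values only).
pg_losses = [0.8, 1.0, 1.2, 1.0]    # hypothetical per-micro-batch pg_loss values
n = len(pg_losses)                  # e.g. mini-batch 8 / micro-batch 2 -> 4 micro-batches
loss_scale_factor = 1.0 / n         # assumed gradient-accumulation scaling

true_mean = sum(pg_losses) / n      # what we want to see logged: 1.0

# Old path: each value is scaled once when logged, then averaged again downstream.
logged_old = sum(x * loss_scale_factor for x in pg_losses) / n  # 0.25 == true_mean / n

# New path: the scaled values are accumulated once, with no further averaging.
logged_new = sum(x * loss_scale_factor for x in pg_losses)      # 1.0 == true_mean

print(true_mean, logged_old, logged_new)
```

With four micro-batches per mini-batch, this reproduces the factor-of-4 gap seen in the `main` run above.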
1 parent 0e9da5e commit b470cbf

File tree

2 files changed (+10, -5 lines)


verl/workers/actor/dp_actor.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -431,7 +431,10 @@ def update_policy(self, data: DataProto):
 
         on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1
 
-        metrics = {}
+        metrics = {
+            "actor/pg_loss": 0.0,
+            "actor/kl_loss": 0.0,
+        }
         for _ in range(self.config.ppo_epochs):
             for batch_idx, mini_batch in enumerate(mini_batches):
                 if self.config.use_dynamic_bsz:
@@ -530,7 +533,7 @@ def update_policy(self, data: DataProto):
                         kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
                         policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                        micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
+                        metrics["actor/kl_loss"] += kl_loss.detach().item() * loss_scale_factor
                         micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef
 
                     if self.config.use_dynamic_bsz:
@@ -543,7 +546,7 @@ def update_policy(self, data: DataProto):
                     else:
                         loss.backward()
 
-                    micro_batch_metrics["actor/pg_loss"] = pg_loss.detach().item() * loss_scale_factor
+                    metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor
                     append_to_dict(metrics, micro_batch_metrics)
 
                 grad_norm = self._optimizer_step()
```
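As a sketch of why moving the value out of `micro_batch_metrics` removes the second scaling: `append_to_dict` collects each micro-batch's entry into a per-key list, and the reducer in `verl/utils/metric/utils.py` (the "second scaling" link above) then averages that list. The helpers below are simplified stand-ins written for illustration, not the exact verl implementations.

```python
import numpy as np

def append_to_dict(data, new_data):
    # Simplified stand-in: collect each micro-batch value into a per-key list.
    for key, val in new_data.items():
        data.setdefault(key, []).append(val)

def reduce_metrics(metrics):
    # Simplified stand-in for the linked reducer: average each collected list.
    return {key: np.mean(val) for key, val in metrics.items()}

kl_losses = [0.4, 0.5, 0.6, 0.5]          # hypothetical per-micro-batch kl_loss values
loss_scale_factor = 1.0 / len(kl_losses)

# Old path: the already-scaled value flows through micro_batch_metrics, so the
# reducer divides by the micro-batch count a second time.
old = {}
for kl in kl_losses:
    append_to_dict(old, {"actor/kl_loss": kl * loss_scale_factor})
print(reduce_metrics(old))                # value is 0.125, i.e. mean / 4

# New path: accumulate directly into `metrics`; the list-averaging never touches it.
new = {"actor/kl_loss": 0.0}
for kl in kl_losses:
    new["actor/kl_loss"] += kl * loss_scale_factor
print(new)                                # value is 0.5, the true micro-batch mean
```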

verl/workers/critic/dp_critic.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -192,7 +192,9 @@ def compute_values(self, data: DataProto) -> torch.Tensor:
     def update_critic(self, data: DataProto):
         # make sure we are in training mode
         self.critic_module.train()
-        metrics = {}
+        metrics = {
+            "critic/vf_loss": 0.0,
+        }
 
         select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids", "values", "returns"]
         has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
@@ -246,12 +248,12 @@ def update_critic(self, data: DataProto):
 
                     micro_batch_metrics.update(
                         {
-                            "critic/vf_loss": vf_loss.detach().item() * loss_scale_factor,
                             "critic/vf_clipfrac": vf_clipfrac.detach().item(),
                             "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
                         }
                     )
 
+                    metrics["critic/vf_loss"] += vf_loss.detach().item() * loss_scale_factor
                     append_to_dict(metrics, micro_batch_metrics)
 
                 grad_norm = self._optimizer_step()
```
