2828
2929import verl .utils .torch_functional as verl_F
3030from verl import DataProto
31- from verl .trainer .ppo .core_algos import agg_loss , compute_policy_loss , compute_sft_loss , get_policy_loss_fn , kl_penalty
31+ from verl .trainer .ppo .core_algos import agg_loss , compute_policy_loss , get_policy_loss_fn , kl_penalty
3232from verl .utils .debug import GPUMemoryLogger
3333from verl .utils .device import get_device_id , get_device_name , is_cuda_available , is_npu_available
3434from verl .utils .fsdp_utils import FSDPModule , fsdp2_clip_grad_norm_
@@ -79,16 +79,13 @@ def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim
7979 )
8080 self .device_name = get_device_name ()
8181
82- def _forward_micro_batch (self , micro_batch , temperature , compute_teacher , calculate_entropy = False ) -> Tuple [torch .Tensor , torch .Tensor ]:
82+ def _forward_micro_batch (self , micro_batch , temperature , calculate_entropy = False ) -> Tuple [torch .Tensor , torch .Tensor ]:
8383 """
8484 Returns:
8585 entropy: # (bs, response_len)
8686 log_probs: # (bs, response_len)
8787 """
88- if compute_teacher :
89- response_length = micro_batch ["teacher_response" ].size (- 1 )
90- else :
91- response_length = micro_batch ["responses" ].size (- 1 )
88+ response_length = micro_batch ["responses" ].size (- 1 )
9289 multi_modal_inputs = {}
9390 if "multi_modal_inputs" in micro_batch .keys ():
9491 for key in micro_batch ["multi_modal_inputs" ][0 ].keys ():
@@ -101,16 +98,10 @@ def _forward_micro_batch(self, micro_batch, temperature, compute_teacher, calcul
10198 multi_modal_inputs [key ] = torch .cat ([inputs [key ] for inputs in micro_batch ["multi_modal_inputs" ]], dim = 0 )
10299
103100 with torch .autocast (device_type = self .device_name , dtype = torch .bfloat16 ):
104- if compute_teacher :
105- input_ids = micro_batch ["teacher_input_ids" ]
106- batch_size , seqlen = input_ids .shape
107- attention_mask = micro_batch ["teacher_attention_mask" ]
108- position_ids = micro_batch ["teacher_position_ids" ]
109- else :
110- input_ids = micro_batch ["input_ids" ]
111- batch_size , seqlen = input_ids .shape
112- attention_mask = micro_batch ["attention_mask" ]
113- position_ids = micro_batch ["position_ids" ]
101+ input_ids = micro_batch ["input_ids" ]
102+ batch_size , seqlen = input_ids .shape
103+ attention_mask = micro_batch ["attention_mask" ]
104+ position_ids = micro_batch ["position_ids" ]
114105 entropy = None
115106 if position_ids .dim () == 3 : # qwen2vl mrope
116107 position_ids = position_ids .transpose (0 , 1 ) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
@@ -315,7 +306,6 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te
315306 Returns:
316307 torch.Tensor: the log_prob tensor
317308 """
318- compute_teacher = data .meta_info ["compute_teacher" ]
319309 # set to eval
320310 self .actor_module .eval ()
321311
@@ -324,10 +314,7 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te
324314 use_dynamic_bsz = data .meta_info ["use_dynamic_bsz" ]
325315
326316 def _get_micro_batches (data : DataProto ) -> Tuple [list , list | None ]:
327- if compute_teacher :
328- select_keys = ["teacher_response" , "teacher_input_ids" , "teacher_attention_mask" , "teacher_position_ids" ]
329- else :
330- select_keys = ["responses" , "input_ids" , "attention_mask" , "position_ids" ]
317+ select_keys = ["responses" , "input_ids" , "attention_mask" , "position_ids" ]
331318 batch = data .select (batch_keys = select_keys ).batch
332319 has_multi_modal_inputs = "multi_modal_inputs" in data .non_tensor_batch
333320
@@ -352,7 +339,7 @@ def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
352339 return micro_batches_dp , None
353340 elif use_dynamic_bsz :
354341 max_token_len = data .meta_info ["max_token_len" ] * self .ulysses_sequence_parallel_size
355- micro_batches , indices = rearrange_micro_batches (batch = batch , max_token_len = max_token_len , compute_teacher = compute_teacher )
342+ micro_batches , indices = rearrange_micro_batches (batch = batch , max_token_len = max_token_len )
356343 return micro_batches , indices
357344 else :
358345 micro_batches = batch .split (micro_batch_size )
@@ -366,7 +353,7 @@ def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
366353 if isinstance (micro_batch , DataProto ):
367354 micro_batch = {** micro_batch .batch , ** micro_batch .non_tensor_batch }
368355 with torch .no_grad ():
369- entropy , log_probs = self ._forward_micro_batch (micro_batch , compute_teacher = compute_teacher , temperature = temperature , calculate_entropy = calculate_entropy )
356+ entropy , log_probs = self ._forward_micro_batch (micro_batch , temperature = temperature , calculate_entropy = calculate_entropy )
370357 log_probs_lst .append (log_probs )
371358 if calculate_entropy :
372359 entropy_lst .append (entropy )
@@ -387,17 +374,13 @@ def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
387374
388375 @GPUMemoryLogger (role = "dp actor" , logger = logger )
389376 def update_policy (self , data : DataProto ):
390-
391377 # make sure we are in training mode
392378 self .actor_module .train ()
393379
394380 temperature = data .meta_info ["temperature" ] # temperature must be in the data.meta_info to avoid silent error
395381 multi_turn = data .meta_info .get ("multi_turn" , False )
396382
397- select_keys = [
398- "responses" , "input_ids" , "attention_mask" , "position_ids" , "old_log_probs" , "advantages" ,
399- "teacher_response" , "teacher_input_ids" , "teacher_attention_mask" , "teacher_position_ids"
400- ]
383+ select_keys = ["responses" , "input_ids" , "attention_mask" , "position_ids" , "old_log_probs" , "advantages" ]
401384 if multi_turn :
402385 select_keys .append ("loss_mask" )
403386 if self .config .use_kl_loss :
@@ -439,7 +422,7 @@ def update_policy(self, data: DataProto):
439422 micro_batches = data .select (select_keys , non_tensor_select_keys ).chunk (num_micro_batches )
440423 elif self .config .use_dynamic_bsz :
441424 max_token_len = self .config .ppo_max_token_len_per_gpu * self .ulysses_sequence_parallel_size
442- micro_batches , _ = rearrange_micro_batches (batch = mini_batch , max_token_len = max_token_len , compute_teacher = False )
425+ micro_batches , _ = rearrange_micro_batches (batch = mini_batch , max_token_len = max_token_len )
443426 else :
444427 self .gradient_accumulation = self .config .ppo_mini_batch_size // self .config .ppo_micro_batch_size_per_gpu
445428 # split batch into micro_batches
@@ -462,16 +445,12 @@ def update_policy(self, data: DataProto):
462445 else :
463446 data = data .to (get_device_id ()) # actor device is cpu when using offload
464447 responses = data ["responses" ]
465- teacher_response = data ["teacher_response" ]
466448 response_length = responses .size (1 )
467- teacher_response_length = teacher_response .size (1 )
468449 attention_mask = data ["attention_mask" ]
469- teacher_attention_mask = data ["teacher_attention_mask" ]
470450 if multi_turn :
471451 response_mask = data ["loss_mask" ][:, - response_length :]
472452 else :
473453 response_mask = attention_mask [:, - response_length :]
474- teacher_response_mask = teacher_attention_mask [:, - teacher_response_length :]
475454
476455 old_log_prob = data ["old_log_probs" ]
477456 advantages = data ["advantages" ]
@@ -487,17 +466,22 @@ def update_policy(self, data: DataProto):
487466 calculate_entropy = False
488467 if entropy_coeff != 0 :
489468 calculate_entropy = True
490- teacher_entropy , teacher_log_prob = self ._forward_micro_batch (micro_batch = data , compute_teacher = True , temperature = temperature , calculate_entropy = calculate_entropy )
469+ entropy , log_prob = self ._forward_micro_batch (micro_batch = data , temperature = temperature , calculate_entropy = calculate_entropy )
491470
492471 loss_mode = self .config .policy_loss .get ("loss_mode" , "vanilla" )
493472
494473 if self .config .policy_loss .loss_mode == "vanilla" :
495- teacher_pg_loss = compute_sft_loss (
496- log_prob = teacher_log_prob ,
497- response_mask = teacher_response_mask ,
474+ pg_loss , pg_clipfrac , ppo_kl , pg_clipfrac_lower = compute_policy_loss (
475+ old_log_prob = old_log_prob ,
476+ log_prob = log_prob ,
477+ advantages = advantages ,
478+ response_mask = response_mask ,
479+ cliprange = clip_ratio ,
480+ cliprange_low = clip_ratio_low ,
481+ cliprange_high = clip_ratio_high ,
482+ clip_ratio_c = clip_ratio_c ,
498483 loss_agg_mode = loss_agg_mode ,
499484 )
500- pg_loss = teacher_pg_loss
501485 else :
502486 policy_loss_fn = get_policy_loss_fn (loss_mode )
503487 pg_loss , pg_clipfrac , ppo_kl , pg_clipfrac_lower = policy_loss_fn (old_log_prob , log_prob , advantages , response_mask , loss_agg_mode , self .config )
@@ -510,6 +494,16 @@ def update_policy(self, data: DataProto):
510494 else :
511495 policy_loss = pg_loss
512496
497+ if self .config .use_kl_loss :
498+ ref_log_prob = data ["ref_log_prob" ]
499+ # compute kl loss
500+ kld = kl_penalty (logprob = log_prob , ref_logprob = ref_log_prob , kl_penalty = self .config .kl_loss_type )
501+ kl_loss = agg_loss (loss_mat = kld , loss_mask = response_mask , loss_agg_mode = loss_agg_mode )
502+
503+ policy_loss = policy_loss + kl_loss * self .config .kl_loss_coef
504+ metrics ["actor/kl_loss" ] = kl_loss .detach ().item ()
505+ metrics ["actor/kl_coef" ] = self .config .kl_loss_coef
506+
513507 if self .config .use_dynamic_bsz :
514508 # relative to the dynamic bsz
515509 loss = policy_loss * (len (data ) / self .config .ppo_mini_batch_size )
@@ -519,7 +513,9 @@ def update_policy(self, data: DataProto):
519513
520514 data = {
521515 "actor/pg_loss" : pg_loss .detach ().item (),
522- "actor/teacher_pg_loss" : teacher_pg_loss .detach ().item (),
516+ "actor/pg_clipfrac" : pg_clipfrac .detach ().item (),
517+ "actor/ppo_kl" : ppo_kl .detach ().item (),
518+ "actor/pg_clipfrac_lower" : pg_clipfrac_lower .detach ().item (),
523519 }
524520 append_to_dict (metrics , data )
525521
0 commit comments