@@ -224,8 +224,8 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
224224 DataProto: The updated data with computed advantages and returns.
225225 """
226226 # Back-compatible with trainers that do not compute response mask in fit
227- if "teacher_response_mask " not in data .batch .keys ():
228- data .batch ["teacher_response_mask " ] = compute_response_mask (data , compute_teacher = True )
227+ if "response_mask " not in data .batch .keys ():
228+ data .batch ["response_mask " ] = compute_response_mask (data , compute_teacher = False )
229229 # prepare response group
230230 if adv_estimator == AdvantageEstimator .GAE :
231231 # Compute advantages and returns using Generalized Advantage Estimation (GAE)
@@ -245,16 +245,23 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
245245 config .get ("pf_ppo_weight_pow" , 2.0 ),
246246 )
247247 elif adv_estimator == AdvantageEstimator .GRPO :
248- grpo_calculation_mask = data .batch ["teacher_response_mask" ]
248+ # Initialize the mask for GRPO calculation
249+ grpo_calculation_mask = data .batch ["response_mask" ]
250+ if multi_turn :
251+ # If multi-turn, replace the mask with the relevant part of loss_mask
252+ # Get length from the initial response mask
253+ response_length = grpo_calculation_mask .size (1 )
254+ # This mask is the one intended for GRPO
255+ grpo_calculation_mask = data .batch ["loss_mask" ][:, - response_length :]
256+ # Call compute_grpo_outcome_advantage with parameters matching its definition
249257 advantages , returns = core_algos .compute_grpo_outcome_advantage (
250- token_level_rewards = data .batch ["teacher_token_level_rewards " ],
258+ token_level_rewards = data .batch ["token_level_rewards " ],
251259 response_mask = grpo_calculation_mask ,
252260 index = data .non_tensor_batch ["uid" ],
253261 norm_adv_by_std_in_grpo = norm_adv_by_std_in_grpo ,
254- compute_teacher = True ,
255262 )
256- data .batch ["teacher_advantages " ] = advantages
257- data .batch ["teacher_returns " ] = returns
263+ data .batch ["advantages " ] = advantages
264+ data .batch ["returns " ] = returns
258265 else :
259266 # handle all other adv estimator type other than GAE and GRPO
260267 adv_estimator_fn = core_algos .get_adv_estimator_fn (adv_estimator )
@@ -330,8 +337,6 @@ def __init__(
330337 self .role_worker_mapping = role_worker_mapping
331338 self .resource_pool_manager = resource_pool_manager
332339 self .use_reference_policy = Role .RefPolicy in role_worker_mapping
333- # NOTE: no reference policy, only teacher sft
334- self .use_reference_policy = False
335340 self .use_rm = Role .RewardModel in role_worker_mapping
336341 self .ray_worker_group_cls = ray_worker_group_cls
337342 self .device_name = device_name
@@ -359,8 +364,8 @@ def __init__(
359364 self .use_critic = False
360365 else :
361366 raise NotImplementedError
362- # NOTE: no critic, only teacher sft
363- self .use_critic = False
367+ # NOTE: we hack critic as reward model. so always use critic
368+ self .use_critic = True
364369
365370 self ._validate_config ()
366371 self ._create_dataloader (train_dataset , val_dataset , collate_fn , train_sampler )
@@ -550,23 +555,18 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
550555 except Exception as e :
551556 print (f"Warning: Could not set total_training_steps in config. Structure missing? Error: { e } " )
552557
553- def _dump_generations (self , inputs , outputs , scores , reward_extra_infos_dict , dump_path ):
558+ def _dump_generations (self , sample_inputs , sample_outputs , teacher_outputs , dump_path ):
554559 """Dump rollout/validation samples as JSONL."""
555560 os .makedirs (dump_path , exist_ok = True )
556- filename = os .path .join (dump_path , f"{ self . global_steps } .jsonl" )
561+ filename = os .path .join (dump_path , f"generation_results .jsonl" )
557562
558- n = len (inputs )
563+ n = len (sample_inputs )
559564 base_data = {
560- "input" : inputs ,
561- "output" : outputs ,
562- "score" : scores ,
563- "step" : [self .global_steps ] * n ,
565+ "input" : sample_inputs ,
566+ "output" : sample_outputs ,
567+ "teacher_output" : teacher_outputs ,
564568 }
565569
566- for k , v in reward_extra_infos_dict .items ():
567- if len (v ) == n :
568- base_data [k ] = v
569-
570570 lines = []
571571 for i in range (n ):
572572 entry = {k : v [i ] for k , v in base_data .items ()}
@@ -691,7 +691,7 @@ def safe_rouge_score(ref, cand):
691691
692692 reward_extra_infos_dict ["reward" ].extend (scores )
693693 print (f"len reward_extra_infos_dict['reward']: { len (reward_extra_infos_dict ['reward' ])} " )
694-
694+
695695 data_source_lst .append (test_batch .non_tensor_batch .get ("data_source" , ["unknown" ] * len (scores )))
696696
697697 self ._maybe_log_val_generations (inputs = sample_inputs , outputs = sample_outputs , scores = sample_scores )
@@ -700,10 +700,9 @@ def safe_rouge_score(ref, cand):
700700 val_data_dir = self .config .trainer .get ("validation_data_dir" , None )
701701 if val_data_dir :
702702 self ._dump_generations (
703- inputs = sample_inputs ,
704- outputs = sample_outputs ,
705- scores = sample_scores ,
706- reward_extra_infos_dict = reward_extra_infos_dict ,
703+ sample_inputs = sample_inputs ,
704+ sample_outputs = sample_outputs ,
705+ teacher_outputs = teacher_outputs ,
707706 dump_path = val_data_dir ,
708707 )
709708
@@ -915,45 +914,6 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle
915914 global_balance_stats = log_seqlen_unbalance (seqlen_list = global_seqlen_lst , partitions = global_partition_lst , prefix = logging_prefix )
916915 metrics .update (global_balance_stats )
917916
918- def _forward_batch_teacher_forcing_grpo (self , batch , teacher_repeat ):
919-
920- response_length = batch ["teacher_response" ].size (- 1 )
921-
922- with torch .autocast (device_type = self .device_name , dtype = torch .bfloat16 ):
923- input_ids = batch ["teacher_input_ids" ]
924- bsz , seqlen = input_ids .shape
925- attention_mask = batch ["teacher_attention_mask" ]
926- position_ids = batch ["teacher_position_ids" ]
927-
928- values = torch .zeros ((bsz , response_length ), device = input_ids .device )
929- response_mask = attention_mask [:, - response_length :]
930- response_lengths = response_mask .sum (dim = 1 ).long ()
931- last_token_indices = response_lengths - 1
932- for i in range (0 , bsz , teacher_repeat ):
933- for j in range (teacher_repeat ):
934- values [i + j , last_token_indices [i + j ]] = float (j )
935- return values
936-
937- def compute_teacher_values (self , data : DataProto ):
938- compute_teacher = data .meta_info ["compute_teacher" ]
939- if compute_teacher :
940- select_keys = ["teacher_response" , "teacher_input_ids" , "teacher_attention_mask" , "teacher_position_ids" ]
941- else :
942- select_keys = ["responses" , "input_ids" , "attention_mask" , "position_ids" ]
943- batch = data .select (batch_keys = select_keys ).batch
944-
945- # teacher forcing for GRPO
946- if compute_teacher :
947- teacher_repeat = data .meta_info ["teacher_repeat" ]
948- uids = data .non_tensor_batch ["uid" ]
949- for i in range (0 , len (uids ), teacher_repeat ):
950- assert all (uids [j ] == uids [i ] for j in range (i , i + teacher_repeat )), f"uids are not the same for a teacher group: { uids [i :i + teacher_repeat ]} "
951- return DataProto .from_dict (
952- tensors = {
953- "teacher_values" : self ._forward_batch_teacher_forcing_grpo (batch , teacher_repeat = teacher_repeat )
954- }
955- )
956-
957917 def fit (self ):
958918 """
959919 The training loop of PPO.
@@ -1009,7 +969,7 @@ def fit(self):
1009969 metrics = {}
1010970 timing_raw = {}
1011971 batch : DataProto = DataProto .from_single_dict (batch_dict )
1012-
972+
1013973 # pop those keys for generation
1014974 batch_keys_to_pop = ["input_ids" , "attention_mask" , "position_ids" , "teacher_response" ]
1015975 non_tensor_batch_keys_to_pop = ["raw_prompt_ids" ]
@@ -1061,7 +1021,7 @@ def fit(self):
10611021 batch = batch .repeat (repeat_times = self .config .actor_rollout_ref .rollout .n , interleave = True )
10621022 batch = batch .union (gen_batch_output )
10631023
1064- batch .batch ["teacher_response_mask " ] = compute_response_mask (batch , compute_teacher = True )
1024+ batch .batch ["response_mask " ] = compute_response_mask (batch )
10651025 # Balance the number of valid tokens across DP ranks.
10661026 # NOTE: This usually changes the order of data in the `batch`,
10671027 # which won't affect the advantage calculation (since it's based on uid),
@@ -1071,10 +1031,102 @@ def fit(self):
10711031 # self._balance_batch(batch, metrics=metrics)
10721032
10731033 # compute global_valid tokens
1074- batch .meta_info ["global_token_num" ] = torch .sum (batch .batch ["teacher_attention_mask" ], dim = - 1 ).tolist ()
1034+ batch .meta_info ["global_token_num" ] = torch .sum (batch .batch ["attention_mask" ], dim = - 1 ).tolist ()
1035+
1036+ # recompute old_log_probs
1037+ with marked_timer ("old_log_prob" , timing_raw , color = "blue" ):
1038+ batch .meta_info ["compute_teacher" ] = False
1039+ old_log_prob = self .actor_rollout_wg .compute_log_prob (batch )
1040+
1041+ entropys = old_log_prob .batch ["entropys" ]
1042+ response_masks = batch .batch ["response_mask" ]
1043+ loss_agg_mode = self .config .actor_rollout_ref .actor .loss_agg_mode
1044+ entropy_agg = agg_loss (loss_mat = entropys , loss_mask = response_masks , loss_agg_mode = loss_agg_mode )
1045+ old_log_prob_metrics = {"actor/entropy" : entropy_agg .detach ().item ()}
1046+ metrics .update (old_log_prob_metrics )
1047+ old_log_prob .batch .pop ("entropys" )
1048+ batch = batch .union (old_log_prob )
1049+
1050+ if "rollout_log_probs" in batch .batch .keys ():
1051+ # TODO: we may want to add diff of probs too.
1052+ rollout_old_log_probs = batch .batch ["rollout_log_probs" ]
1053+ actor_old_log_probs = batch .batch ["old_log_probs" ]
1054+ attention_mask = batch .batch ["attention_mask" ]
1055+ responses = batch .batch ["responses" ]
1056+ response_length = responses .size (1 )
1057+ response_mask = attention_mask [:, - response_length :]
1058+
1059+ rollout_probs = torch .exp (rollout_old_log_probs )
1060+ actor_probs = torch .exp (actor_old_log_probs )
1061+ rollout_probs_diff = torch .abs (rollout_probs - actor_probs )
1062+ rollout_probs_diff = torch .masked_select (rollout_probs_diff , response_mask .bool ())
1063+ rollout_probs_diff_max = torch .max (rollout_probs_diff )
1064+ rollout_probs_diff_mean = torch .mean (rollout_probs_diff )
1065+ rollout_probs_diff_std = torch .std (rollout_probs_diff )
1066+ metrics .update (
1067+ {
1068+ "training/rollout_probs_diff_max" : rollout_probs_diff_max .detach ().item (),
1069+ "training/rollout_probs_diff_mean" : rollout_probs_diff_mean .detach ().item (),
1070+ "training/rollout_probs_diff_std" : rollout_probs_diff_std .detach ().item (),
1071+ }
1072+ )
1073+
1074+ if self .use_reference_policy :
1075+ # compute reference log_prob
1076+ with marked_timer ("ref" , timing_raw , color = "olive" ):
1077+ if not self .ref_in_actor :
1078+ ref_log_prob = self .ref_policy_wg .compute_ref_log_prob (batch )
1079+ else :
1080+ ref_log_prob = self .actor_rollout_wg .compute_ref_log_prob (batch )
1081+ batch = batch .union (ref_log_prob )
1082+
1083+ # NOTE: we use critic to calculate score of student here
1084+ with marked_timer ("reward" , timing_raw , color = "yellow" ):
1085+ future_reward = None
1086+ reward_extra_infos_dict = {}
1087+ batch .meta_info ["compute_teacher" ] = False
1088+ values = self .critic_wg .compute_values (batch )
1089+ batch = batch .union (values )
1090+ reward_tensor = batch .batch ["values" ]
1091+
1092+ with marked_timer ("adv" , timing_raw , color = "brown" ):
1093+ # we combine with rule-based rm
1094+ reward_extra_infos_dict : dict [str , list ]
1095+ if self .config .reward_model .launch_reward_fn_async :
1096+ reward_tensor , reward_extra_infos_dict = ray .get (future_reward )
1097+ batch .batch ["token_level_scores" ] = reward_tensor
1098+
1099+ if reward_extra_infos_dict :
1100+ batch .non_tensor_batch .update ({k : np .array (v ) for k , v in reward_extra_infos_dict .items ()})
1101+
1102+ # compute rewards. apply_kl_penalty if available
1103+ if self .config .algorithm .use_kl_in_reward :
1104+ batch , kl_metrics = apply_kl_penalty (batch , kl_ctrl = self .kl_ctrl_in_reward , kl_penalty = self .config .algorithm .kl_penalty )
1105+ metrics .update (kl_metrics )
1106+ else :
1107+ batch .batch ["token_level_rewards" ] = batch .batch ["token_level_scores" ]
1108+
1109+ # compute advantages, executed on the driver process
1110+ norm_adv_by_std_in_grpo = self .config .algorithm .get ("norm_adv_by_std_in_grpo" , True ) # GRPO adv normalization factor
1111+
1112+ batch = compute_advantage (
1113+ batch ,
1114+ adv_estimator = self .config .algorithm .adv_estimator ,
1115+ gamma = self .config .algorithm .gamma ,
1116+ lam = self .config .algorithm .lam ,
1117+ num_repeat = self .config .actor_rollout_ref .rollout .n ,
1118+ norm_adv_by_std_in_grpo = norm_adv_by_std_in_grpo ,
1119+ multi_turn = self .config .actor_rollout_ref .rollout .multi_turn .enable ,
1120+ config = self .config .algorithm ,
1121+ )
1122+
1123+ # update critic
1124+ if self .use_critic :
1125+ with marked_timer ("update_critic" , timing_raw , color = "pink" ):
1126+ critic_output = self .critic_wg .update_critic (batch )
1127+ critic_output_metrics = reduce_metrics (critic_output .meta_info ["metrics" ])
1128+ metrics .update (critic_output_metrics )
10751129
1076- batch .meta_info ["temperature" ] = self .config .actor_rollout_ref .rollout .temperature
1077-
10781130 # implement critic warmup
10791131 if self .config .trainer .critic_warmup <= self .global_steps :
10801132 # update actor
@@ -1120,6 +1172,7 @@ def fit(self):
11201172 }
11211173 )
11221174 # collect metrics
1175+ metrics .update (compute_data_metrics (batch = batch , use_critic = self .use_critic ))
11231176 metrics .update (compute_timing_metrics (batch = batch , timing_raw = timing_raw ))
11241177 # TODO: implement actual tflpo and theoretical tflpo
11251178 n_gpus = self .resource_pool_manager .get_n_gpus ()
0 commit comments