Skip to content

Commit 0c38dd2

Browse files
committed
warmup
1 parent ec44c40 commit 0c38dd2

File tree

7 files changed

+215
-170
lines changed

7 files changed

+215
-170
lines changed

verl/trainer/ppo/core_algos.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,16 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
566566
return loss
567567

568568

569+
def compute_sft_loss(
570+
log_prob,
571+
response_mask,
572+
loss_agg_mode: str = "token-mean",
573+
):
574+
pg_loss = agg_loss(loss_mat=-log_prob, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
575+
576+
return pg_loss
577+
578+
569579
def compute_policy_loss(
570580
old_log_prob,
571581
log_prob,
@@ -632,16 +642,6 @@ def compute_policy_loss(
632642
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
633643

634644

635-
def compute_sft_loss(
636-
log_prob,
637-
response_mask,
638-
loss_agg_mode: str = "token-mean",
639-
):
640-
pg_loss = agg_loss(loss_mat=-log_prob, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
641-
642-
return pg_loss
643-
644-
645645
@register_policy_loss("clip_cov")
646646
def compute_policy_loss_clip_cov(
647647
old_log_prob,

verl/trainer/ppo/metric_utils.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,18 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
6262
- prompt_length: Tensor of prompt lengths for each item in the batch
6363
- response_length: Tensor of response lengths for each item in the batch
6464
"""
65-
response_length = batch.batch["teacher_response"].shape[-1]
65+
response_length = batch.batch["responses"].shape[-1]
6666

67-
prompt_mask = batch.batch["teacher_attention_mask"][:, :-response_length]
68-
response_mask = batch.batch["teacher_attention_mask"][:, -response_length:]
67+
prompt_mask = batch.batch["attention_mask"][:, :-response_length]
68+
response_mask = batch.batch["attention_mask"][:, -response_length:]
6969

7070
prompt_length = prompt_mask.sum(-1).float()
7171
response_length = response_mask.sum(-1).float() # (batch_size,)
7272

7373
return dict(
74-
teacher_response_mask=response_mask,
74+
response_mask=response_mask,
7575
prompt_length=prompt_length,
76-
teacher_response_length=response_length,
76+
response_length=response_length,
7777
)
7878

7979

@@ -100,22 +100,22 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str,
100100
- response_length/mean, max, min, clip_ratio: Statistics about response lengths
101101
- prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
102102
"""
103-
sequence_score = batch.batch["teacher_token_level_scores"].sum(-1)
104-
sequence_reward = batch.batch["teacher_token_level_rewards"].sum(-1)
103+
sequence_score = batch.batch["token_level_scores"].sum(-1)
104+
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
105105

106-
advantages = batch.batch["teacher_advantages"]
107-
returns = batch.batch["teacher_returns"]
106+
advantages = batch.batch["advantages"]
107+
returns = batch.batch["returns"]
108108

109-
max_response_length = batch.batch["teacher_response"].shape[-1]
109+
max_response_length = batch.batch["responses"].shape[-1]
110110

111-
prompt_mask = batch.batch["teacher_attention_mask"][:, :-max_response_length].bool()
112-
response_mask = batch.batch["teacher_attention_mask"][:, -max_response_length:].bool()
111+
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
112+
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
113113

114114
max_prompt_length = prompt_mask.size(-1)
115115

116116
response_info = _compute_response_info(batch)
117117
prompt_length = response_info["prompt_length"]
118-
response_length = response_info["teacher_response_length"]
118+
response_length = response_info["response_length"]
119119

120120
valid_adv = torch.masked_select(advantages, response_mask)
121121
valid_returns = torch.masked_select(returns, response_mask)
@@ -194,7 +194,7 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
194194
"""
195195
response_info = _compute_response_info(batch)
196196
num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
197-
num_response_tokens = torch.sum(response_info["teacher_response_length"]).item()
197+
num_response_tokens = torch.sum(response_info["response_length"]).item()
198198
num_overall_tokens = num_prompt_tokens + num_response_tokens
199199

200200
num_tokens_of_section = {

verl/trainer/ppo/ray_trainer.py

Lines changed: 124 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
224224
DataProto: The updated data with computed advantages and returns.
225225
"""
226226
# Back-compatible with trainers that do not compute response mask in fit
227-
if "teacher_response_mask" not in data.batch.keys():
228-
data.batch["teacher_response_mask"] = compute_response_mask(data, compute_teacher=True)
227+
if "response_mask" not in data.batch.keys():
228+
data.batch["response_mask"] = compute_response_mask(data, compute_teacher=False)
229229
# prepare response group
230230
if adv_estimator == AdvantageEstimator.GAE:
231231
# Compute advantages and returns using Generalized Advantage Estimation (GAE)
@@ -245,16 +245,23 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
245245
config.get("pf_ppo_weight_pow", 2.0),
246246
)
247247
elif adv_estimator == AdvantageEstimator.GRPO:
248-
grpo_calculation_mask = data.batch["teacher_response_mask"]
248+
# Initialize the mask for GRPO calculation
249+
grpo_calculation_mask = data.batch["response_mask"]
250+
if multi_turn:
251+
# If multi-turn, replace the mask with the relevant part of loss_mask
252+
# Get length from the initial response mask
253+
response_length = grpo_calculation_mask.size(1)
254+
# This mask is the one intended for GRPO
255+
grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:]
256+
# Call compute_grpo_outcome_advantage with parameters matching its definition
249257
advantages, returns = core_algos.compute_grpo_outcome_advantage(
250-
token_level_rewards=data.batch["teacher_token_level_rewards"],
258+
token_level_rewards=data.batch["token_level_rewards"],
251259
response_mask=grpo_calculation_mask,
252260
index=data.non_tensor_batch["uid"],
253261
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
254-
compute_teacher=True,
255262
)
256-
data.batch["teacher_advantages"] = advantages
257-
data.batch["teacher_returns"] = returns
263+
data.batch["advantages"] = advantages
264+
data.batch["returns"] = returns
258265
else:
259266
# handle all other adv estimator type other than GAE and GRPO
260267
adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)
@@ -330,8 +337,6 @@ def __init__(
330337
self.role_worker_mapping = role_worker_mapping
331338
self.resource_pool_manager = resource_pool_manager
332339
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
333-
# NOTE: no reference policy, only teacher sft
334-
self.use_reference_policy = False
335340
self.use_rm = Role.RewardModel in role_worker_mapping
336341
self.ray_worker_group_cls = ray_worker_group_cls
337342
self.device_name = device_name
@@ -359,8 +364,8 @@ def __init__(
359364
self.use_critic = False
360365
else:
361366
raise NotImplementedError
362-
# NOTE: no critic, only teacher sft
363-
self.use_critic = False
367+
# NOTE: we hack critic as reward model. so always use critic
368+
self.use_critic = True
364369

365370
self._validate_config()
366371
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
@@ -550,23 +555,18 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
550555
except Exception as e:
551556
print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
552557

553-
def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path):
558+
def _dump_generations(self, sample_inputs, sample_outputs, teacher_outputs, dump_path):
554559
"""Dump rollout/validation samples as JSONL."""
555560
os.makedirs(dump_path, exist_ok=True)
556-
filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
561+
filename = os.path.join(dump_path, f"generation_results.jsonl")
557562

558-
n = len(inputs)
563+
n = len(sample_inputs)
559564
base_data = {
560-
"input": inputs,
561-
"output": outputs,
562-
"score": scores,
563-
"step": [self.global_steps] * n,
565+
"input": sample_inputs,
566+
"output": sample_outputs,
567+
"teacher_output": teacher_outputs,
564568
}
565569

566-
for k, v in reward_extra_infos_dict.items():
567-
if len(v) == n:
568-
base_data[k] = v
569-
570570
lines = []
571571
for i in range(n):
572572
entry = {k: v[i] for k, v in base_data.items()}
@@ -691,7 +691,7 @@ def safe_rouge_score(ref, cand):
691691

692692
reward_extra_infos_dict["reward"].extend(scores)
693693
print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}")
694-
694+
695695
data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(scores)))
696696

697697
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
@@ -700,10 +700,9 @@ def safe_rouge_score(ref, cand):
700700
val_data_dir = self.config.trainer.get("validation_data_dir", None)
701701
if val_data_dir:
702702
self._dump_generations(
703-
inputs=sample_inputs,
704-
outputs=sample_outputs,
705-
scores=sample_scores,
706-
reward_extra_infos_dict=reward_extra_infos_dict,
703+
sample_inputs=sample_inputs,
704+
sample_outputs=sample_outputs,
705+
teacher_outputs=teacher_outputs,
707706
dump_path=val_data_dir,
708707
)
709708

@@ -915,45 +914,6 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle
915914
global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix)
916915
metrics.update(global_balance_stats)
917916

918-
def _forward_batch_teacher_forcing_grpo(self, batch, teacher_repeat):
919-
920-
response_length = batch["teacher_response"].size(-1)
921-
922-
with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
923-
input_ids = batch["teacher_input_ids"]
924-
bsz, seqlen = input_ids.shape
925-
attention_mask = batch["teacher_attention_mask"]
926-
position_ids = batch["teacher_position_ids"]
927-
928-
values = torch.zeros((bsz, response_length), device=input_ids.device)
929-
response_mask = attention_mask[:, -response_length:]
930-
response_lengths = response_mask.sum(dim=1).long()
931-
last_token_indices = response_lengths - 1
932-
for i in range(0, bsz, teacher_repeat):
933-
for j in range(teacher_repeat):
934-
values[i + j, last_token_indices[i + j]] = float(j)
935-
return values
936-
937-
def compute_teacher_values(self, data: DataProto):
938-
compute_teacher = data.meta_info["compute_teacher"]
939-
if compute_teacher:
940-
select_keys = ["teacher_response", "teacher_input_ids", "teacher_attention_mask", "teacher_position_ids"]
941-
else:
942-
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
943-
batch = data.select(batch_keys=select_keys).batch
944-
945-
# teacher forcing for GRPO
946-
if compute_teacher:
947-
teacher_repeat = data.meta_info["teacher_repeat"]
948-
uids = data.non_tensor_batch["uid"]
949-
for i in range(0, len(uids), teacher_repeat):
950-
assert all(uids[j] == uids[i] for j in range(i, i + teacher_repeat)), f"uids are not the same for a teacher group: {uids[i:i+teacher_repeat]}"
951-
return DataProto.from_dict(
952-
tensors={
953-
"teacher_values": self._forward_batch_teacher_forcing_grpo(batch, teacher_repeat=teacher_repeat)
954-
}
955-
)
956-
957917
def fit(self):
958918
"""
959919
The training loop of PPO.
@@ -1009,7 +969,7 @@ def fit(self):
1009969
metrics = {}
1010970
timing_raw = {}
1011971
batch: DataProto = DataProto.from_single_dict(batch_dict)
1012-
972+
1013973
# pop those keys for generation
1014974
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids", "teacher_response"]
1015975
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
@@ -1061,7 +1021,7 @@ def fit(self):
10611021
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
10621022
batch = batch.union(gen_batch_output)
10631023

1064-
batch.batch["teacher_response_mask"] = compute_response_mask(batch, compute_teacher=True)
1024+
batch.batch["response_mask"] = compute_response_mask(batch)
10651025
# Balance the number of valid tokens across DP ranks.
10661026
# NOTE: This usually changes the order of data in the `batch`,
10671027
# which won't affect the advantage calculation (since it's based on uid),
@@ -1071,10 +1031,102 @@ def fit(self):
10711031
# self._balance_batch(batch, metrics=metrics)
10721032

10731033
# compute global_valid tokens
1074-
batch.meta_info["global_token_num"] = torch.sum(batch.batch["teacher_attention_mask"], dim=-1).tolist()
1034+
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
1035+
1036+
# recompute old_log_probs
1037+
with marked_timer("old_log_prob", timing_raw, color="blue"):
1038+
batch.meta_info["compute_teacher"] = False
1039+
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
1040+
1041+
entropys = old_log_prob.batch["entropys"]
1042+
response_masks = batch.batch["response_mask"]
1043+
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
1044+
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
1045+
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
1046+
metrics.update(old_log_prob_metrics)
1047+
old_log_prob.batch.pop("entropys")
1048+
batch = batch.union(old_log_prob)
1049+
1050+
if "rollout_log_probs" in batch.batch.keys():
1051+
# TODO: we may want to add diff of probs too.
1052+
rollout_old_log_probs = batch.batch["rollout_log_probs"]
1053+
actor_old_log_probs = batch.batch["old_log_probs"]
1054+
attention_mask = batch.batch["attention_mask"]
1055+
responses = batch.batch["responses"]
1056+
response_length = responses.size(1)
1057+
response_mask = attention_mask[:, -response_length:]
1058+
1059+
rollout_probs = torch.exp(rollout_old_log_probs)
1060+
actor_probs = torch.exp(actor_old_log_probs)
1061+
rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
1062+
rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
1063+
rollout_probs_diff_max = torch.max(rollout_probs_diff)
1064+
rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
1065+
rollout_probs_diff_std = torch.std(rollout_probs_diff)
1066+
metrics.update(
1067+
{
1068+
"training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
1069+
"training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
1070+
"training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
1071+
}
1072+
)
1073+
1074+
if self.use_reference_policy:
1075+
# compute reference log_prob
1076+
with marked_timer("ref", timing_raw, color="olive"):
1077+
if not self.ref_in_actor:
1078+
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
1079+
else:
1080+
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
1081+
batch = batch.union(ref_log_prob)
1082+
1083+
# NOTE: we use critic to calculate score of student here
1084+
with marked_timer("reward", timing_raw, color="yellow"):
1085+
future_reward = None
1086+
reward_extra_infos_dict = {}
1087+
batch.meta_info["compute_teacher"] = False
1088+
values = self.critic_wg.compute_values(batch)
1089+
batch = batch.union(values)
1090+
reward_tensor = batch.batch["values"]
1091+
1092+
with marked_timer("adv", timing_raw, color="brown"):
1093+
# we combine with rule-based rm
1094+
reward_extra_infos_dict: dict[str, list]
1095+
if self.config.reward_model.launch_reward_fn_async:
1096+
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
1097+
batch.batch["token_level_scores"] = reward_tensor
1098+
1099+
if reward_extra_infos_dict:
1100+
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
1101+
1102+
# compute rewards. apply_kl_penalty if available
1103+
if self.config.algorithm.use_kl_in_reward:
1104+
batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
1105+
metrics.update(kl_metrics)
1106+
else:
1107+
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
1108+
1109+
# compute advantages, executed on the driver process
1110+
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor
1111+
1112+
batch = compute_advantage(
1113+
batch,
1114+
adv_estimator=self.config.algorithm.adv_estimator,
1115+
gamma=self.config.algorithm.gamma,
1116+
lam=self.config.algorithm.lam,
1117+
num_repeat=self.config.actor_rollout_ref.rollout.n,
1118+
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
1119+
multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable,
1120+
config=self.config.algorithm,
1121+
)
1122+
1123+
# update critic
1124+
if self.use_critic:
1125+
with marked_timer("update_critic", timing_raw, color="pink"):
1126+
critic_output = self.critic_wg.update_critic(batch)
1127+
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
1128+
metrics.update(critic_output_metrics)
10751129

1076-
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
1077-
10781130
# implement critic warmup
10791131
if self.config.trainer.critic_warmup <= self.global_steps:
10801132
# update actor
@@ -1120,6 +1172,7 @@ def fit(self):
11201172
}
11211173
)
11221174
# collect metrics
1175+
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
11231176
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
11241177
# TODO: implement actual tflpo and theoretical tflpo
11251178
n_gpus = self.resource_pool_manager.get_n_gpus()

0 commit comments

Comments
 (0)