Skip to content

Commit e39e564

Browse files
committed
gad train
1 parent 0c38dd2 commit e39e564

File tree

6 files changed

+70
-127
lines changed

6 files changed

+70
-127
lines changed

verl/trainer/ppo/core_algos.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ def compute_grpo_outcome_advantage(
205205
index: np.ndarray,
206206
epsilon: float = 1e-6,
207207
norm_adv_by_std_in_grpo: str = True,
208-
compute_teacher: bool = False,
209208
):
210209
"""
211210
Compute advantage for GRPO, operating only on Outcome reward
@@ -230,7 +229,6 @@ def compute_grpo_outcome_advantage(
230229
scores = token_level_rewards.sum(dim=-1)
231230

232231
id2score = defaultdict(list)
233-
id2adv = defaultdict(list)
234232
id2mean = {}
235233
id2std = {}
236234

@@ -252,12 +250,6 @@ def compute_grpo_outcome_advantage(
252250
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
253251
else:
254252
scores[i] = scores[i] - id2mean[index[i]]
255-
id2adv[index[i]].append(scores[i])
256-
257-
if compute_teacher:
258-
for i in range(bsz):
259-
scores[i] = max(id2adv[index[i]])
260-
261253
scores = scores.unsqueeze(-1) * response_mask
262254

263255
return scores, scores
@@ -566,16 +558,6 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
566558
return loss
567559

568560

569-
def compute_sft_loss(
570-
log_prob,
571-
response_mask,
572-
loss_agg_mode: str = "token-mean",
573-
):
574-
pg_loss = agg_loss(loss_mat=-log_prob, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
575-
576-
return pg_loss
577-
578-
579561
def compute_policy_loss(
580562
old_log_prob,
581563
log_prob,

verl/trainer/ppo/ray_trainer.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
180180
return data, metrics
181181

182182

183-
def compute_response_mask(data: DataProto, compute_teacher=False):
183+
def compute_response_mask(data: DataProto):
184184
"""Compute the attention mask for the response part of the sequence.
185185
186186
This function extracts the portion of the attention mask that corresponds to the model's response,
@@ -192,16 +192,10 @@ def compute_response_mask(data: DataProto, compute_teacher=False):
192192
Returns:
193193
torch.Tensor: The attention mask for the response tokens.
194194
"""
195-
if compute_teacher:
196-
responses = data.batch["teacher_response"]
197-
response_length = responses.size(1)
198-
attention_mask = data.batch["teacher_attention_mask"]
199-
return attention_mask[:, -response_length:]
200-
else:
201-
responses = data.batch["responses"]
202-
response_length = responses.size(1)
203-
attention_mask = data.batch["attention_mask"]
204-
return attention_mask[:, -response_length:]
195+
responses = data.batch["responses"]
196+
response_length = responses.size(1)
197+
attention_mask = data.batch["attention_mask"]
198+
return attention_mask[:, -response_length:]
205199

206200

207201
def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True, config=None):
@@ -225,7 +219,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
225219
"""
226220
# Back-compatible with trainers that do not compute response mask in fit
227221
if "response_mask" not in data.batch.keys():
228-
data.batch["response_mask"] = compute_response_mask(data, compute_teacher=False)
222+
data.batch["response_mask"] = compute_response_mask(data)
229223
# prepare response group
230224
if adv_estimator == AdvantageEstimator.GAE:
231225
# Compute advantages and returns using Generalized Advantage Estimation (GAE)
@@ -555,18 +549,23 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
555549
except Exception as e:
556550
print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
557551

558-
def _dump_generations(self, sample_inputs, sample_outputs, teacher_outputs, dump_path):
552+
def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path):
559553
"""Dump rollout/validation samples as JSONL."""
560554
os.makedirs(dump_path, exist_ok=True)
561-
filename = os.path.join(dump_path, f"generation_results.jsonl")
555+
filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
562556

563-
n = len(sample_inputs)
557+
n = len(inputs)
564558
base_data = {
565-
"input": sample_inputs,
566-
"output": sample_outputs,
567-
"teacher_output": teacher_outputs,
559+
"input": inputs,
560+
"output": outputs,
561+
"score": scores,
562+
"step": [self.global_steps] * n,
568563
}
569564

565+
for k, v in reward_extra_infos_dict.items():
566+
if len(v) == n:
567+
base_data[k] = v
568+
570569
lines = []
571570
for i in range(n):
572571
entry = {k: v[i] for k, v in base_data.items()}
@@ -700,9 +699,10 @@ def safe_rouge_score(ref, cand):
700699
val_data_dir = self.config.trainer.get("validation_data_dir", None)
701700
if val_data_dir:
702701
self._dump_generations(
703-
sample_inputs=sample_inputs,
704-
sample_outputs=sample_outputs,
705-
teacher_outputs=teacher_outputs,
702+
inputs=sample_inputs,
703+
outputs=sample_outputs,
704+
scores=sample_scores,
705+
reward_extra_infos_dict=reward_extra_infos_dict,
706706
dump_path=val_data_dir,
707707
)
708708

@@ -886,11 +886,12 @@ def _load_checkpoint(self):
886886

887887
actor_path = os.path.join(global_step_folder, "actor")
888888
critic_path = os.path.join(global_step_folder, "critic")
889+
# NOTE: have directly loaded from actor_rollout_ref.model.path and critic.model.path
889890
# load actor
890-
self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
891+
# self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
891892
# load critic
892-
if self.use_critic:
893-
self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
893+
# if self.use_critic and os.path.exists(critic_path):
894+
# self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
894895

895896
# load dataloader,
896897
# TODO: from remote not implemented yet
@@ -1035,9 +1036,7 @@ def fit(self):
10351036

10361037
# recompute old_log_probs
10371038
with marked_timer("old_log_prob", timing_raw, color="blue"):
1038-
batch.meta_info["compute_teacher"] = False
10391039
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
1040-
10411040
entropys = old_log_prob.batch["entropys"]
10421041
response_masks = batch.batch["response_mask"]
10431042
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
@@ -1084,18 +1083,18 @@ def fit(self):
10841083
with marked_timer("reward", timing_raw, color="yellow"):
10851084
future_reward = None
10861085
reward_extra_infos_dict = {}
1087-
batch.meta_info["compute_teacher"] = False
10881086
values = self.critic_wg.compute_values(batch)
10891087
batch = batch.union(values)
10901088
reward_tensor = batch.batch["values"]
1089+
# reward_tensor: (bsz, response_length)
10911090

10921091
with marked_timer("adv", timing_raw, color="brown"):
10931092
# we combine with rule-based rm
10941093
reward_extra_infos_dict: dict[str, list]
10951094
if self.config.reward_model.launch_reward_fn_async:
10961095
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
10971096
batch.batch["token_level_scores"] = reward_tensor
1098-
1097+
10991098
if reward_extra_infos_dict:
11001099
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
11011100

@@ -1107,6 +1106,7 @@ def fit(self):
11071106
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
11081107

11091108
# compute advantages, executed on the driver process
1109+
11101110
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor
11111111

11121112
batch = compute_advantage(

verl/utils/dataset/rl_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def __getitem__(self, item):
267267

268268
if teacher_response is not None:
269269
teacher_response = self.tokenizer(teacher_response, return_tensors="pt", add_special_tokens=False)
270-
270+
271271
if not self.processor_type == "MiniCPMVImageProcessor":
272272
input_ids, attention_mask = verl_F.postprocess_data(
273273
input_ids=input_ids,

verl/workers/actor/dp_actor.py

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import verl.utils.torch_functional as verl_F
3030
from verl import DataProto
31-
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, compute_sft_loss, get_policy_loss_fn, kl_penalty
31+
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
3232
from verl.utils.debug import GPUMemoryLogger
3333
from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available
3434
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
@@ -79,16 +79,13 @@ def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim
7979
)
8080
self.device_name = get_device_name()
8181

82-
def _forward_micro_batch(self, micro_batch, temperature, compute_teacher, calculate_entropy=False) -> Tuple[torch.Tensor, torch.Tensor]:
82+
def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False) -> Tuple[torch.Tensor, torch.Tensor]:
8383
"""
8484
Returns:
8585
entropy: # (bs, response_len)
8686
log_probs: # (bs, response_len)
8787
"""
88-
if compute_teacher:
89-
response_length = micro_batch["teacher_response"].size(-1)
90-
else:
91-
response_length = micro_batch["responses"].size(-1)
88+
response_length = micro_batch["responses"].size(-1)
9289
multi_modal_inputs = {}
9390
if "multi_modal_inputs" in micro_batch.keys():
9491
for key in micro_batch["multi_modal_inputs"][0].keys():
@@ -101,16 +98,10 @@ def _forward_micro_batch(self, micro_batch, temperature, compute_teacher, calcul
10198
multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0)
10299

103100
with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
104-
if compute_teacher:
105-
input_ids = micro_batch["teacher_input_ids"]
106-
batch_size, seqlen = input_ids.shape
107-
attention_mask = micro_batch["teacher_attention_mask"]
108-
position_ids = micro_batch["teacher_position_ids"]
109-
else:
110-
input_ids = micro_batch["input_ids"]
111-
batch_size, seqlen = input_ids.shape
112-
attention_mask = micro_batch["attention_mask"]
113-
position_ids = micro_batch["position_ids"]
101+
input_ids = micro_batch["input_ids"]
102+
batch_size, seqlen = input_ids.shape
103+
attention_mask = micro_batch["attention_mask"]
104+
position_ids = micro_batch["position_ids"]
114105
entropy = None
115106
if position_ids.dim() == 3: # qwen2vl mrope
116107
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
@@ -315,7 +306,6 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te
315306
Returns:
316307
torch.Tensor: the log_prob tensor
317308
"""
318-
compute_teacher = data.meta_info["compute_teacher"]
319309
# set to eval
320310
self.actor_module.eval()
321311

@@ -324,10 +314,7 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te
324314
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
325315

326316
def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
327-
if compute_teacher:
328-
select_keys = ["teacher_response", "teacher_input_ids", "teacher_attention_mask", "teacher_position_ids"]
329-
else:
330-
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
317+
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
331318
batch = data.select(batch_keys=select_keys).batch
332319
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch
333320

@@ -352,7 +339,7 @@ def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
352339
return micro_batches_dp, None
353340
elif use_dynamic_bsz:
354341
max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
355-
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len, compute_teacher=compute_teacher)
342+
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
356343
return micro_batches, indices
357344
else:
358345
micro_batches = batch.split(micro_batch_size)
@@ -366,7 +353,7 @@ def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
366353
if isinstance(micro_batch, DataProto):
367354
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
368355
with torch.no_grad():
369-
entropy, log_probs = self._forward_micro_batch(micro_batch, compute_teacher=compute_teacher, temperature=temperature, calculate_entropy=calculate_entropy)
356+
entropy, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=calculate_entropy)
370357
log_probs_lst.append(log_probs)
371358
if calculate_entropy:
372359
entropy_lst.append(entropy)
@@ -387,17 +374,13 @@ def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
387374

388375
@GPUMemoryLogger(role="dp actor", logger=logger)
389376
def update_policy(self, data: DataProto):
390-
391377
# make sure we are in training mode
392378
self.actor_module.train()
393379

394380
temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error
395381
multi_turn = data.meta_info.get("multi_turn", False)
396382

397-
select_keys = [
398-
"responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages",
399-
"teacher_response", "teacher_input_ids", "teacher_attention_mask", "teacher_position_ids"
400-
]
383+
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
401384
if multi_turn:
402385
select_keys.append("loss_mask")
403386
if self.config.use_kl_loss:
@@ -439,7 +422,7 @@ def update_policy(self, data: DataProto):
439422
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
440423
elif self.config.use_dynamic_bsz:
441424
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
442-
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len, compute_teacher=False)
425+
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
443426
else:
444427
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
445428
# split batch into micro_batches
@@ -462,16 +445,12 @@ def update_policy(self, data: DataProto):
462445
else:
463446
data = data.to(get_device_id()) # actor device is cpu when using offload
464447
responses = data["responses"]
465-
teacher_response = data["teacher_response"]
466448
response_length = responses.size(1)
467-
teacher_response_length = teacher_response.size(1)
468449
attention_mask = data["attention_mask"]
469-
teacher_attention_mask = data["teacher_attention_mask"]
470450
if multi_turn:
471451
response_mask = data["loss_mask"][:, -response_length:]
472452
else:
473453
response_mask = attention_mask[:, -response_length:]
474-
teacher_response_mask = teacher_attention_mask[:, -teacher_response_length:]
475454

476455
old_log_prob = data["old_log_probs"]
477456
advantages = data["advantages"]
@@ -487,17 +466,22 @@ def update_policy(self, data: DataProto):
487466
calculate_entropy = False
488467
if entropy_coeff != 0:
489468
calculate_entropy = True
490-
teacher_entropy, teacher_log_prob = self._forward_micro_batch(micro_batch=data, compute_teacher=True, temperature=temperature, calculate_entropy=calculate_entropy)
469+
entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy)
491470

492471
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
493472

494473
if self.config.policy_loss.loss_mode == "vanilla":
495-
teacher_pg_loss = compute_sft_loss(
496-
log_prob=teacher_log_prob,
497-
response_mask=teacher_response_mask,
474+
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
475+
old_log_prob=old_log_prob,
476+
log_prob=log_prob,
477+
advantages=advantages,
478+
response_mask=response_mask,
479+
cliprange=clip_ratio,
480+
cliprange_low=clip_ratio_low,
481+
cliprange_high=clip_ratio_high,
482+
clip_ratio_c=clip_ratio_c,
498483
loss_agg_mode=loss_agg_mode,
499484
)
500-
pg_loss = teacher_pg_loss
501485
else:
502486
policy_loss_fn = get_policy_loss_fn(loss_mode)
503487
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config)
@@ -510,6 +494,16 @@ def update_policy(self, data: DataProto):
510494
else:
511495
policy_loss = pg_loss
512496

497+
if self.config.use_kl_loss:
498+
ref_log_prob = data["ref_log_prob"]
499+
# compute kl loss
500+
kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
501+
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
502+
503+
policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
504+
metrics["actor/kl_loss"] = kl_loss.detach().item()
505+
metrics["actor/kl_coef"] = self.config.kl_loss_coef
506+
513507
if self.config.use_dynamic_bsz:
514508
# relative to the dynamic bsz
515509
loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
@@ -519,7 +513,9 @@ def update_policy(self, data: DataProto):
519513

520514
data = {
521515
"actor/pg_loss": pg_loss.detach().item(),
522-
"actor/teacher_pg_loss": teacher_pg_loss.detach().item(),
516+
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
517+
"actor/ppo_kl": ppo_kl.detach().item(),
518+
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
523519
}
524520
append_to_dict(metrics, data)
525521

0 commit comments

Comments
 (0)