@@ -355,9 +355,6 @@ def __call__(
355355 dp_gang = self ._gangs .dp ,
356356 vllm_model = self ._vllm_model ,
357357 )
358- # if self._gangs.root.rank == 0:
359- # breakpoint()
360- # self._gangs.root.barrier()
361358 if self ._config .clip_rollout_after_think is not None :
362359 prompt_batch .meta_info ["suffix" ] = [
363360 self ._rollout_tokenizer .decode (
@@ -417,6 +414,27 @@ def __call__(
417414 grpo_input_batch_seqs , grpo_input_batch_seqs_layout
418415 )
419416
417+ # if self._gangs.root.rank == 0:
418+ # breakpoint()
419+ # self._gangs.root.barrier()
420+
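421+ # FIXME: the NLL loss below only works for batch_size = 1 for now: it encodes only the first suffix (index 0) and repeats that single target across every row of the batch.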
422+ # suffix_text = prompt_batch.meta_info.get("suffix")[0]
423+ # targets = (
424+ # torch.Tensor(self._rollout_tokenizer.encode(suffix_text))
425+ # .repeat(grpo_input_batch_seqs.size(0), 1)
426+ # .to(grpo_input_batch_seqs.device)
427+ # )
428+ # target_mask = torch.ones_like(targets).to(targets.device).float()
429+
430+ # nll_loss, chosen_logits = self._model.module(
431+ # grpo_input_batch_seqs,
432+ # grpo_input_batch_seqs_layout,
433+ # targets=targets,
434+ # target_mask=target_mask,
435+ # return_logits=True,
436+ # )
437+
420438 model_logps = self ._gather_lprobs (grpo_model_logits , grpo_target_batch )
421439 rollout_window = self ._rollout_bag .get_rollout_start_end (
422440 self ._config .loss_config .forward_group_size
0 commit comments