Skip to content

Commit a6013ce

Browse files
start nll loss
1 parent 6c6048c commit a6013ce

File tree

1 file changed

+21
-3
lines changed
  • src/fairseq2/recipes/lm/_online_finetune

1 file changed

+21
-3
lines changed

src/fairseq2/recipes/lm/_online_finetune/_grpo.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,6 @@ def __call__(
355355
dp_gang=self._gangs.dp,
356356
vllm_model=self._vllm_model,
357357
)
358-
# if self._gangs.root.rank == 0:
359-
# breakpoint()
360-
# self._gangs.root.barrier()
361358
if self._config.clip_rollout_after_think is not None:
362359
prompt_batch.meta_info["suffix"] = [
363360
self._rollout_tokenizer.decode(
@@ -417,6 +414,27 @@ def __call__(
417414
grpo_input_batch_seqs, grpo_input_batch_seqs_layout
418415
)
419416

417+
# if self._gangs.root.rank == 0:
418+
# breakpoint()
419+
# self._gangs.root.barrier()
420+
421+
# FIXME NLL loss only works for batch_size = 1 for now
422+
# suffix_text = prompt_batch.meta_info.get("suffix")[0]
423+
# targets = (
424+
# torch.Tensor(self._rollout_tokenizer.encode(suffix_text))
425+
# .repeat(grpo_input_batch_seqs.size(0), 1)
426+
# .to(grpo_input_batch_seqs.device)
427+
# )
428+
# target_mask = torch.ones_like(targets).to(targets.device).float()
429+
430+
# nll_loss, chosen_logits = self._model.module(
431+
# grpo_input_batch_seqs,
432+
# grpo_input_batch_seqs_layout,
433+
# targets=targets,
434+
# target_mask=target_mask,
435+
# return_logits=True,
436+
# )
437+
420438
model_logps = self._gather_lprobs(grpo_model_logits, grpo_target_batch)
421439
rollout_window = self._rollout_bag.get_rollout_start_end(
422440
self._config.loss_config.forward_group_size

0 commit comments

Comments
 (0)