Skip to content

Commit 3297cd1

Browse files
committed
Fix a typing issue
Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com>
1 parent 4733717 commit 3297cd1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

nemo_rl/algorithms/loss_functions.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -568,7 +568,7 @@ def __init__(self, cfg: DPOLossConfig):
568568

569569
self.loss_type = LossType.SEQUENCE_LEVEL
570570

571-
def _preference_loss(
571+
def _dpo_loss(
572572
self,
573573
next_token_logits: Tensor,
574574
data: BatchedDataDict[DPOLossDataDict],
@@ -619,7 +619,7 @@ def _preference_loss(
619619
if self.preference_average_log_probs:
620620
rewards = rewards / token_mask.sum(-1).clamp(min=1)
621621

622-
return super()._preference_loss(
622+
return self._preference_loss(
623623
rewards, sample_mask, global_valid_seqs, self.reference_policy_kl_penalty
624624
)
625625

@@ -661,7 +661,7 @@ def __call__(
661661
accuracy,
662662
rewards_chosen_mean,
663663
rewards_rejected_mean,
664-
) = self._preference_loss(
664+
) = self._dpo_loss(
665665
next_token_logits,
666666
data,
667667
global_valid_seqs,

0 commit comments

Comments (0)