Skip to content

Commit b048a2d

Browse files
committed
fix dpo bug
1 parent 8555549 commit b048a2d

File tree

6 files changed

+6
-489
lines changed

6 files changed

+6
-489
lines changed

cosyvoice/llm/llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2-
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
2+
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -420,8 +420,8 @@ def forward_dpo(
420420
rejected_lm_mask = rejected_lm_target == IGNORE_ID
421421
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
422422
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
423-
chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
424-
rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
423+
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
424+
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
425425
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
426426

427427
@torch.inference_mode()

examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

Lines changed: 0 additions & 257 deletions
This file was deleted.

0 commit comments

Comments
 (0)