Skip to content

Commit 678f5a9

Browse files
author
Tong Li
committed
update loss
1 parent b96d690 commit 678f5a9

File tree

1 file changed

+5
-2
lines changed
  • applications/ColossalChat/coati/distributed

1 file changed

+5
-2
lines changed

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,16 +10,18 @@ class PolicyLoss(nn.Module):
1010
Policy Loss for PPO
1111
"""
1212

13-
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
13+
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
1414
super().__init__()
1515
self.clip_eps = clip_eps
1616
self.skip_threshold = skip_threshold
17+
self.beta = beta
1718

1819
def forward(
1920
self,
2021
log_probs: torch.Tensor,
2122
old_log_probs: torch.Tensor,
2223
advantages: torch.Tensor,
24+
per_token_kl: torch.Tensor,
2325
action_mask: Optional[torch.Tensor] = None,
2426
) -> torch.Tensor:
2527
skip = False
@@ -35,7 +37,8 @@ def forward(
3537
ratio = ratio_.clamp(0.0, 10.0)
3638
surr1 = ratio * advantages
3739
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
38-
loss = -torch.min(surr1, surr2)
40+
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
41+
3942
if action_mask is not None:
4043
loss = masked_mean(loss, action_mask)
4144
else:

0 commit comments

Comments
 (0)