Commit 704866a

detach

Tong Li committed
1 parent 47d6493

File tree

  • applications/ColossalChat/coati/distributed

1 file changed (+3 −8 lines)

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 3 additions & 8 deletions
@@ -26,15 +26,10 @@ def forward(
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
-            ratio_ = (log_probs - old_log_probs).exp()
+            ratio = (log_probs - log_probs.detach()).exp()
         else:
-            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
 
-        # note that if dropout is disabled (recommanded), ratio will always be 1.
-        if ratio_.mean() > self.skip_threshold:
-            skip = True
-
-        ratio = ratio_.clamp(0.0, 10.0)
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@@ -44,4 +39,4 @@ def forward(
         else:
             loss = loss.mean(dim=1)
         loss = loss.mean()
-        return loss, skip, ratio_.max()
+        return loss, skip, ratio.max()
