applications/ColossalChat/coati/distributed — 1 file changed: +3 −8

@@ -26,15 +26,10 @@ def forward(
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
-            ratio_ = (log_probs - old_log_probs).exp()
+            ratio = (log_probs - log_probs.detach()).exp()
         else:
-            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
 
-        # note that if dropout is disabled (recommanded), ratio will always be 1.
-        if ratio_.mean() > self.skip_threshold:
-            skip = True
-
-        ratio = ratio_.clamp(0.0, 10.0)
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@@ -44,4 +39,4 @@ def forward(
         else:
             loss = loss.mean(dim=1)
         loss = loss.mean()
-        return loss, skip, ratio_.max()
+        return loss, skip, ratio.max()
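
The substance of the change: the probability ratio is now computed against log_probs.detach() rather than old_log_probs, so every token's ratio is exactly 1 in value while the gradient still flows through log_probs. The skip-threshold check and the clamp(0.0, 10.0) safeguard can therefore never trigger and are removed, and the unclamped ratio is returned directly. A minimal sketch (illustrative tensors, not from the PR) of why the new expression is value-1 but still gradient-carrying:

import torch

# Toy inputs standing in for per-token log-probabilities and advantages.
log_probs = torch.tensor([-1.2, -0.3, -2.0], requires_grad=True)
advantages = torch.tensor([0.5, -1.0, 2.0])

# The new ratio: exp(x - x.detach()) evaluates to 1 everywhere, so the
# PPO clip and the old skip/clamp safeguards become no-ops.
ratio = (log_probs - log_probs.detach()).exp()
print(ratio)  # tensor([1., 1., 1.], grad_fn=<ExpBackward0>)

# The gradient is not zero, though: d(ratio)/d(log_probs) = ratio = 1, so
# -(ratio * advantages).mean() reduces to the plain policy-gradient surrogate.
loss = -(ratio * advantages).mean()
loss.backward()
print(log_probs.grad)  # tensor([-0.1667,  0.3333, -0.6667]) == -advantages / 3

This is consistent with a strictly on-policy, single-update-per-rollout setting, where old_log_probs would coincide with log_probs.detach() anyway.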