File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
applications/ColossalChat/coati/distributed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -10,16 +10,18 @@ class PolicyLoss(nn.Module):
10
10
Policy Loss for PPO
11
11
"""
12
12
13
- def __init__ (self , clip_eps : float = 0.2 , skip_threshold : float = 20.0 ) -> None :
13
+ def __init__ (self , clip_eps : float = 0.2 , skip_threshold : float = 20.0 , beta : float = 0.01 ) -> None :
14
14
super ().__init__ ()
15
15
self .clip_eps = clip_eps
16
16
self .skip_threshold = skip_threshold
17
+ self .beta = beta
17
18
18
19
def forward (
19
20
self ,
20
21
log_probs : torch .Tensor ,
21
22
old_log_probs : torch .Tensor ,
22
23
advantages : torch .Tensor ,
24
+ per_token_kl : torch .Tensor ,
23
25
action_mask : Optional [torch .Tensor ] = None ,
24
26
) -> torch .Tensor :
25
27
skip = False
@@ -35,7 +37,8 @@ def forward(
35
37
ratio = ratio_ .clamp (0.0 , 10.0 )
36
38
surr1 = ratio * advantages
37
39
surr2 = ratio .clamp (1 - self .clip_eps , 1 + self .clip_eps ) * advantages
38
- loss = - torch .min (surr1 , surr2 )
40
+ loss = - torch .min (surr1 , surr2 ) + self .beta * per_token_kl
41
+
39
42
if action_mask is not None :
40
43
loss = masked_mean (loss , action_mask )
41
44
else :
You can’t perform that action at this time.
0 commit comments