File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
trinity/algorithm/advantage_fn Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -119,12 +119,13 @@ def calculate_group_advantage(
119119 self , group_id : str , exps : List [Experience ]
120120 ) -> Tuple [List [Experience ], Dict ]:
121121 with torch .no_grad ():
122+ group_rewards = torch .tensor ([exp .reward for exp in exps ], dtype = torch .float32 )
123+ reward_mean = torch .mean (group_rewards )
122124 if len (exps ) == 1 :
123125 group_baseline = torch .tensor (0.0 )
124126 else :
125- group_rewards = torch .tensor ([exp .reward for exp in exps ], dtype = torch .float32 )
126127 if self .opmd_baseline == "mean" :
127- group_baseline = torch . mean ( group_rewards )
128+ group_baseline = reward_mean
128129 else :
129130 group_baseline = self .tau * (
130131 torch .logsumexp (group_rewards / self .tau , dim = - 1 )
@@ -136,7 +137,7 @@ def calculate_group_advantage(
136137 exp .returns = exp .advantages .clone ()
137138 metrics = {
138139 "group_baseline" : group_baseline .item (),
139- "reward_mean" : torch . mean ( group_rewards ) .item (),
140+ "reward_mean" : reward_mean .item (),
140141 }
141142 return exps , metrics
142143
You can’t perform that action at this time.
0 commit comments