@@ -56,6 +56,9 @@ def __init__(
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_count = 0
 
         # Reference model is initialized from policy model.
@@ -131,9 +134,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
 
-            reward = self.reward_model(
+            reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
+
+            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
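For context on the change above: the reward model is now expected to return, for each sample, a tuple whose first three entries are the total reward, a format reward, and an accuracy reward. Below is a minimal sketch of a verifier with that output shape and of the same unpacking pattern; toy_reward_fn and its scoring heuristics are illustrative assumptions, not the repo's implementation.

# Minimal sketch (not the repo's reward model): a verifier returning, per sample,
# (total_reward, format_reward, accuracy_reward) -- the shape the unpacking above assumes.
import torch

def toy_reward_fn(responses, gt_answers):
    """Hypothetical stand-in for self.reward_model; names and rules are illustrative."""
    reward_group = []
    for resp, gt in zip(responses, gt_answers):
        format_reward = 1.0 if resp.startswith("<think>") and "</think>" in resp else 0.0
        acc_reward = 1.0 if gt in resp else 0.0
        reward_group.append((format_reward + acc_reward, format_reward, acc_reward))
    return reward_group

reward_group = toy_reward_fn(["<think>...</think> 42"], ["42"])
# Same unpacking pattern as in the diff: one tensor per reward component.
reward = torch.tensor([v[0] for v in reward_group])
format_reward = torch.tensor([v[1] for v in reward_group])
acc_reward = torch.tensor([v[2] for v in reward_group])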
@@ -157,9 +165,16 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
             kl = all_reduce_mean(kl.mean(), self.plugin)
+            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+            advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            # Calculate accumulated values.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_format_reward.add_(format_reward.data)
+            self.accum_acc_reward.add_(acc_reward.data)
+            self.accum_advantages.add_(advantages.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
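The new metrics are averaged across data-parallel ranks before being accumulated. A sketch of what an all_reduce_mean-style helper usually does is below; the repo's version also takes the booster plugin to select the process group, which this sketch omits (assumption).

# Sketch of a typical all_reduce_mean helper (illustrative, not the repo's exact version).
import torch
import torch.distributed as dist

def all_reduce_mean_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Sum the value across all data-parallel ranks, then divide by world size,
    # so every rank accumulates the same averaged metric.
    if dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor = tensor / dist.get_world_size()
    return tensor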
@@ -173,17 +188,28 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     self.accum_reward.item() / self.accum_count,
                     "KL:",
                     self.accum_kl.item() / self.accum_count,
+                    "Format Reward:",
+                    self.accum_format_reward.item() / self.accum_count,
+                    "Acc Reward:",
+                    self.accum_acc_reward.item() / self.accum_count,
+                    "Advantages:",
+                    self.accum_advantages.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
+                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/advantages": self.accum_advantages.item() / self.accum_count,
                     }
                 )
             self.accum_loss.zero_()
             self.accum_reward.zero_()
             self.accum_kl.zero_()
+            self.accum_acc_reward.zero_()
+            self.accum_format_reward.zero_()
             self.accum_count = 0
             return loss_scalar
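After the optimizer step, each accumulated metric is averaged over accum_count micro-batches, printed and logged to wandb, and the buffers are zeroed; note that in this hunk accum_advantages is not reset alongside the other buffers. Below is a generic sketch of that accumulate/average/reset pattern (illustrative helper, not part of the repo).

# Generic accumulate -> average -> log -> reset helper (illustrative only).
import torch

class MetricAccumulator:
    def __init__(self, names, device="cpu"):
        self.buffers = {n: torch.zeros(1, device=device) for n in names}
        self.count = 0

    def add(self, **values):
        # Accumulate one micro-batch worth of already-reduced scalar metrics.
        for name, value in values.items():
            self.buffers[name].add_(value)
        self.count += 1

    def flush(self):
        # Return per-step averages and reset every buffer, so no metric
        # (e.g. an "advantages" buffer) is accidentally carried over.
        averages = {n: b.item() / self.count for n, b in self.buffers.items()}
        for b in self.buffers.values():
            b.zero_()
        self.count = 0
        return averages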