@@ -59,6 +59,7 @@ def __init__(
         self.accum_format_reward = torch.zeros(1, device=self.device)
         self.accum_acc_reward = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
@@ -83,7 +84,7 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

     def setup(self):
         super().setup()
@@ -109,6 +110,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
+        response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0

@@ -168,13 +170,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
         acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
         advantages = all_reduce_mean(advantages.mean(), self.plugin)
+        response_length = all_reduce_mean(response_length.mean(), self.plugin)
         # Calculate accumulate value.
         self.accum_loss.add_(loss.data)
         self.accum_reward.add_(reward.data)
         self.accum_kl.add_(kl.data)
         self.accum_format_reward.add_(format_reward.data)
         self.accum_acc_reward.add_(acc_reward.data)
         self.accum_advantages.add_(advantages.data)
+        self.accum_response_length.add_(response_length.data)
         self.accum_count += 1
         if need_update:
             self.optimizer.step()
@@ -184,32 +188,38 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 print(
                     "Loss:",
                     self.accum_loss.item() / self.accum_count,
-                    "Reward:",
+                    "\nReward:",
                     self.accum_reward.item() / self.accum_count,
-                    "KL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "Format Reward:",
+                    "\nFormat Reward:",
                     self.accum_format_reward.item() / self.accum_count,
-                    "Acc Reward:",
+                    "\nAcc Reward:",
                     self.accum_acc_reward.item() / self.accum_count,
-                    "Advantages:",
+                    "\nKL:",
+                    self.accum_kl.item() / self.accum_count,
+                    "\nAdvantages:",
                     self.accum_advantages.item() / self.accum_count,
+                    "\nResponse Length:",
+                    self.accum_response_length.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/format_reward": self.accum_format_reward.item() / self.accum_count,
                         "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
+                        "train/response_length": self.accum_response_length.item() / self.accum_count,
                     }
                 )
             self.accum_loss.zero_()
             self.accum_reward.zero_()
-            self.accum_kl.zero_()
             self.accum_acc_reward.zero_()
             self.accum_format_reward.zero_()
+            self.accum_kl.zero_()
+            self.accum_advantages.zero_()
+            self.accum_response_length.zero_()
+
             self.accum_count = 0
             return loss_scalar

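For context on what the new metric measures, below is a minimal, standalone sketch (not part of the patch) of how `response_length = torch.sum(action_mask, dim=1).to(torch.float32)` turns the per-token action mask into a per-sample token count before it is all-reduced and accumulated. The helper name and tensor shapes are illustrative assumptions, not code from this repository.

# Illustrative sketch only: mirrors the response-length computation added in the diff.
import torch

def mean_response_length(action_mask: torch.Tensor) -> torch.Tensor:
    """Count generated tokens per sample (1 = generated token, 0 = padding),
    then average over the batch, mirroring
    torch.sum(action_mask, dim=1).to(torch.float32) in the patch."""
    per_sample = torch.sum(action_mask, dim=1).to(torch.float32)  # shape: (batch,)
    return per_sample.mean()

# Example: two responses of length 3 and 1 within a window of 4 action slots.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]])
print(mean_response_length(mask))  # tensor(2.)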