@@ -56,6 +56,10 @@ def __init__(
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_advantages = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
 
         # Reference model is initialized from policy model.
@@ -80,7 +84,7 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
 
     def setup(self):
         super().setup()
@@ -106,6 +110,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
+        response_length = torch.sum(action_mask, dim=1).to(torch.float32)
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
@@ -133,9 +138,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
 
-            reward = self.reward_model(
+            reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
+
+            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
 
@@ -159,9 +169,18 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
             kl = all_reduce_mean(kl.mean(), self.plugin)
+            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+            advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            response_length = all_reduce_mean(response_length.mean(), self.plugin)
+            # Accumulate values for logging.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_format_reward.add_(format_reward.data)
+            self.accum_acc_reward.add_(acc_reward.data)
+            self.accum_advantages.add_(advantages.data)
+            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
@@ -171,21 +190,38 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 print(
                     "Loss:",
                     self.accum_loss.item() / self.accum_count,
-                    "Reward:",
+                    "\nReward:",
                     self.accum_reward.item() / self.accum_count,
-                    "KL:",
+                    "\nFormat Reward:",
+                    self.accum_format_reward.item() / self.accum_count,
+                    "\nAcc Reward:",
+                    self.accum_acc_reward.item() / self.accum_count,
+                    "\nKL:",
                     self.accum_kl.item() / self.accum_count,
+                    "\nAdvantages:",
+                    self.accum_advantages.item() / self.accum_count,
+                    "\nResponse Length:",
+                    self.accum_response_length.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
+                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
+                        "train/advantages": self.accum_advantages.item() / self.accum_count,
+                        "train/response_length": self.accum_response_length.item() / self.accum_count,
                     }
                 )
             self.accum_loss.zero_()
             self.accum_reward.zero_()
+            self.accum_acc_reward.zero_()
+            self.accum_format_reward.zero_()
             self.accum_kl.zero_()
+            self.accum_advantages.zero_()
+            self.accum_response_length.zero_()
+
             self.accum_count = 0
             return loss_scalar
 
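Taken together, the commit decomposes the scalar reward into a format component and an accuracy component, tracks advantages and response length, and averages everything over the accumulated micro-batches before logging. Below is a minimal, self-contained sketch of that bookkeeping pattern in plain PyTorch; the MetricAccumulator helper, the hard-coded reward triples, and the advantage-normalization epsilon are illustrative assumptions, not part of the actual GRPOConsumer or its reward model.

# Sketch only: a standalone illustration of the metric bookkeeping added by this
# commit. `MetricAccumulator` and the toy reward triples are hypothetical; the
# real consumer keeps per-metric accum_* tensors and reduces them across ranks.
import torch


class MetricAccumulator:
    """Keeps running sums of scalar metrics and reports their means."""

    def __init__(self, names, device="cpu"):
        self.sums = {name: torch.zeros(1, device=device) for name in names}
        self.count = 0

    def add(self, **metrics):
        # Accumulate one micro-batch worth of scalar tensors.
        for name, value in metrics.items():
            self.sums[name].add_(value.detach())
        self.count += 1

    def pop_means(self):
        # Average over the accumulated micro-batches, then reset, mirroring the
        # accum_*.item() / accum_count reads and accum_*.zero_() calls above.
        means = {name: (total / self.count).item() for name, total in self.sums.items()}
        for total in self.sums.values():
            total.zero_()
        self.count = 0
        return means


# Suppose the reward model returns one (total, format, accuracy) triple per sample,
# which is what the unpacking of `reward_group` in the diff suggests.
reward_group = [(1.0, 0.5, 0.5), (0.0, 0.0, 0.0), (1.5, 0.5, 1.0), (0.5, 0.5, 0.0)]
reward = torch.tensor([r[0] for r in reward_group])
format_reward = torch.tensor([r[1] for r in reward_group])
acc_reward = torch.tensor([r[2] for r in reward_group])

# Group-normalized advantages over num_generations completions per prompt
# (a common GRPO formulation; the exact epsilon here is an assumption).
num_generations = 2
group_reward = reward.view(-1, num_generations)
reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations)
advantages = (reward - reward_mean) / (reward_std + 1e-4)

metrics = MetricAccumulator(["reward", "format_reward", "acc_reward", "advantages"])
metrics.add(
    reward=reward.mean(),
    format_reward=format_reward.mean(),
    acc_reward=acc_reward.mean(),
    advantages=advantages.mean(),
)
print(metrics.pop_means())  # {'reward': 0.75, 'format_reward': 0.375, ...}

Logging the reward components and response length separately makes it easy to see whether reward gains come from formatting compliance or from answer accuracy, and whether responses grow longer as training progresses.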