@@ -1,8 +1,11 @@
+import json
+import os
 from contextlib import nullcontext
 from typing import Optional

 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -33,6 +36,8 @@ def __init__(
         microbatch_size=1,
         num_generations=4,
         use_wandb=True,
+        generator_config=None,
+        filter_range=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +74,9 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
+        self.filter_range = filter_range
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."

         # Initialize verifiable reward.
         response_format_tags = {
@@ -84,7 +92,11 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            if "repetition_penalty" in generator_config:
+                name = f"{generator_config['backend']}_bs_{self.batch_size * self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
+            else:
+                name = f"{generator_config['backend']}_bs_{self.batch_size * self.world_size}_temp_{generator_config['temperature']:.01f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)

     def setup(self):
         super().setup()
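For reference, a standalone sketch of the run names this naming scheme produces; the backend, batch size, and sampling settings below are made-up example values, not taken from this PR:

# Hypothetical values, for illustration only.
generator_config = {"backend": "vllm", "temperature": 0.8, "repetition_penalty": 1.1}
batch_size, world_size = 4, 4

if "repetition_penalty" in generator_config:
    name = (
        f"{generator_config['backend']}_bs_{batch_size * world_size}"
        f"_temp_{generator_config['temperature']:.01f}"
        f"_rep_penalty_{generator_config['repetition_penalty']:.01f}"
    )
else:
    name = f"{generator_config['backend']}_bs_{batch_size * world_size}_temp_{generator_config['temperature']:.01f}"

print(name)  # vllm_bs_16_temp_0.8_rep_penalty_1.1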
@@ -121,7 +133,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
             )

             with torch.no_grad():
@@ -130,7 +142,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     attention_mask=data["attention_mask"],
                 )["logits"]
             reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
             )

             per_token_kl = (
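Dividing the logits by the sampling temperature before computing action log-probs appears intended to score tokens under the same softened distribution (softmax of logits / T) that the rollout backend sampled from. A minimal, self-contained sketch of that effect, independent of the consumer code in this diff:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 5, 32)            # toy [batch, seq_len, vocab] logits
token_ids = torch.randint(0, 32, (1, 5))  # toy sampled token ids
temperature = 0.8                         # hypothetical sampling temperature

# Log-probs of the sampled tokens under the temperature-scaled distribution.
scaled = F.log_softmax(logits / temperature, dim=-1)
action_log_probs = scaled.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)

# Without the rescaling, the log-probs refer to a different (sharper) distribution.
unscaled = F.log_softmax(logits, dim=-1).gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(action_log_probs, unscaled))  # False whenever temperature != 1.0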
@@ -149,7 +161,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

             # [batch_size, num_generations]
+            # filter out rewards that are too high (all samples get full score) or too low (all samples get 0 score)
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
+            )
             group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)

             # [batch_size x num_generations]
             reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
@@ -164,6 +183,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                 per_token_kl,
                 action_mask,
+                loss_mask=loss_mask,
             )

             if not skip_update:
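A small standalone sketch of what loss_mask evaluates to for a toy reward tensor; the filter_range below is a hypothetical value chosen only for illustration. Responses whose reward falls outside the open interval get a False entry and are meant to be excluded from the policy loss, per the comment in the hunk above:

import torch

filter_range = [0.01, 0.99]                  # hypothetical bounds
reward = torch.tensor([0.0, 0.3, 1.0, 0.7])  # toy per-response rewards

loss_mask = (
    None
    if filter_range is None
    else torch.logical_and(reward > filter_range[0], reward < filter_range[1])
)
print(loss_mask)  # tensor([False,  True, False,  True])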
@@ -232,3 +252,125 @@ def state_dict(self):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )
+
+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
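Since GRPOEvalConsumer appends one eval_results_rank_{rank}.jsonl file per rank, the per-rank outputs can be merged offline. The helper below is a sketch and is not part of this PR; summarize_eval and its default log_dir are assumptions:

import glob
import json

def summarize_eval(log_dir: str = "./results") -> dict:
    """Merge per-rank eval_results_rank_*.jsonl files and average the logged metrics."""
    records = []
    for path in glob.glob(f"{log_dir}/eval_results_rank_*.jsonl"):
        with open(path, encoding="utf8") as f:
            records.extend(json.loads(line) for line in f if line.strip())
    if not records:
        return {}
    keys = ("reward", "format_reward", "acc_reward", "response_length")
    return {k: sum(r[k] for r in records) / len(records) for k in keys}

if __name__ == "__main__":
    print(summarize_eval())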