@@ -1,8 +1,11 @@
+import json
+import os
 from contextlib import nullcontext
 from typing import Optional

 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -12,6 +15,7 @@
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam

@@ -31,8 +35,10 @@ def __init__(
         model_config,
         plugin_config,
         microbatch_size=1,
-        num_generations=4,
+        num_generations=8,
         use_wandb=True,
+        generate_config=None,
+        training_config={},
     ):
         super().__init__(
             num_producers,
@@ -52,7 +58,7 @@ def __init__(
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -61,6 +67,7 @@ def __init__(
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
+        self.generate_config = generate_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -69,6 +76,9 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
+        self.filter_range = training_config.get("filter_range", None)
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."

         # Initialize verifiable reward.
         response_format_tags = {
@@ -84,11 +94,21 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            name = f"{generate_config['backend']}_bs_{self.batch_size * self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+
+        self.lr_scheduler = CosineAnnealingWarmupLR(
+            optimizer=self.optimizer,
+            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+            warmup_steps=0,
+            eta_min=0.1 * training_config.get("lr", 1e-6),
+        )

     def setup(self):
         super().setup()
-        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+        )
         self.reference_model, *_ = self.booster.boost(self.reference_model)

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
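Note on the scheduler wired up in the hunk above: the learning rate decays from the configured lr down to eta_min = 0.1 * lr over total_steps updates, with no warmup here. A minimal standalone sketch of that schedule shape (illustrative only; the helper name and boundary handling are assumptions, not the ColossalAI implementation):

import math

def cosine_warmup_lr(step, base_lr, total_steps, warmup_steps=0, eta_min=0.0):
    # Linear warmup from 0 to base_lr, then cosine decay from base_lr down to eta_min.
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * min(progress, 1.0)))

# With the default lr=1e-6 above, the rate starts at 1e-6 and anneals to 1e-7.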
@@ -113,15 +133,17 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0
-
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )

             with torch.no_grad():
@@ -130,7 +152,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     attention_mask=data["attention_mask"],
                 )["logits"]
             reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                reference_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )

             per_token_kl = (
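Both call sites above now divide the raw logits by the rollout sampling temperature before computing per-token log-probabilities, so the policy ratio and KL are measured against the same temperature-scaled distribution the inference backend actually sampled from. A minimal single-device sketch of that computation, assuming unsharded logits of shape [batch, seq_len, vocab] (the function name is hypothetical; the real calc_action_log_probs also handles the shard config):

import torch

def action_log_probs_from_logits(logits, input_ids, num_actions, temperature=1.0):
    # logits: [batch, seq_len, vocab]; input_ids: [batch, seq_len]
    logits = logits / temperature                          # match the sampling distribution
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:].unsqueeze(-1)               # logits at position t predict token t+1
    per_token = log_probs.gather(dim=-1, index=targets).squeeze(-1)
    return per_token[:, -num_actions:]                     # keep only the generated (action) tokens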
@@ -149,21 +174,31 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

             # [batch_size, num_generations]
+
             group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
+            # filter out prompts whose mean reward is too high (all samples get full score) or too low (all samples get 0 score)
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(
+                    reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+                ).repeat_interleave(self.num_generations, dim=0)
+            )

             # [batch_size x num_generations]
-            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
             advantages = (reward - reward_mean) / (reward_std + 1e-4)

-            # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                 per_token_kl,
                 action_mask,
+                loss_mask=loss_mask,
             )

             if not skip_update:
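The hunk above is the core group-relative change: rewards are grouped per prompt, normalized by the group mean and standard deviation, and prompts whose mean reward falls outside filter_range (groups where every sample is fully correct or fully wrong) are masked out, since they carry no relative learning signal. A self-contained sketch of that logic, assuming a flat reward tensor with each prompt's generations stored contiguously (the helper name and example filter bounds are illustrative):

import torch

def group_relative_advantages(reward, num_generations, filter_range=None, eps=1e-4):
    group = reward.view(-1, num_generations)              # [batch_size, num_generations]
    mean = group.mean(dim=1)                              # per-prompt mean reward
    std = group.std(dim=1)                                # per-prompt reward spread
    loss_mask = None
    if filter_range is not None:
        low, high = filter_range
        keep = (mean > low) & (mean < high)               # drop all-correct / all-wrong groups
        loss_mask = keep.repeat_interleave(num_generations, dim=0)
    mean = mean.repeat_interleave(num_generations, dim=0)
    std = std.repeat_interleave(num_generations, dim=0)
    advantages = (reward - mean) / (std + eps)            # eps avoids division by zero
    return advantages, loss_mask

# Example: 2 prompts x 4 generations; the second group (all rewards 1.0) is masked out.
# rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0])
# adv, mask = group_relative_advantages(rewards, 4, filter_range=(0.01, 0.99))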
@@ -207,13 +242,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 )
                 self.wandb_run.log(
                     {
+                        "metrics/reward": self.accum_reward.item() / self.accum_count,
+                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/response_length": self.accum_response_length.item() / self.accum_count,
+                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                 )
                 self.accum_loss.zero_()
@@ -232,3 +269,125 @@ def state_dict(self):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )
+
+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict