@@ -52,10 +52,11 @@ def __init__(
52
52
self .policy_model = AutoModelForCausalLM .from_pretrained (path , ** model_config )
53
53
self .policy_model .train ()
54
54
self .policy_model .gradient_checkpointing_enable ()
55
- self .optimizer = HybridAdam (self .policy_model .parameters (), lr = 1e-4 )
55
+ self .optimizer = HybridAdam (self .policy_model .parameters (), lr = 1e-6 )
56
56
self .accum_loss = torch .zeros (1 , device = self .device )
57
57
self .accum_reward = torch .zeros (1 , device = self .device )
58
58
self .accum_kl = torch .zeros (1 , device = self .device )
59
+ self .accum_count = 0
59
60
60
61
# Reference model is initialized from policy model.
61
62
self .reference_model = AutoModelForCausalLM .from_pretrained (path , ** model_config )
@@ -79,13 +80,7 @@ def __init__(
79
80
self .policy_loss_fn = PolicyLoss ()
80
81
self .global_step = 0
81
82
if self .rank == 0 :
82
- self .wandb_run = wandb .init (project = "Colossal-GRPO-Test6" , sync_tensorboard = True )
83
- # import os
84
- # import time
85
-
86
- # log_dir = self.wandb_run.dir
87
- # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
88
- # # self.writer = SummaryWriter(log_dir=log_dir)
83
+ self .wandb_run = wandb .init (project = "GRPO-Test" , sync_tensorboard = True )
89
84
90
85
def setup (self ):
91
86
super ().setup ()
@@ -129,66 +124,67 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
129
124
)["logits" ]
130
125
reference_action_log_probs = calc_action_log_probs (reference_model_logits , data ["input_ids" ], num_action )
131
126
132
- # GRPO advantage calculation
133
- kl = torch .sum (- 0.1 * (action_log_probs - reference_action_log_probs ) * action_mask , dim = - 1 ) / torch .sum (
134
- action_mask , dim = - 1
127
+ per_token_kl = (
128
+ torch .exp (reference_action_log_probs - action_log_probs )
129
+ - (reference_action_log_probs - action_log_probs )
130
+ - 1
135
131
)
132
+ kl = torch .sum (per_token_kl * action_mask , dim = - 1 ) / torch .sum (action_mask , dim = - 1 )
136
133
137
134
reward = self .reward_model (
138
135
data ["input_ids" ], gt_answer = data ["gt_answer" ], response_idx = data ["response_idx" ]
139
136
)
140
- reward = kl + reward
141
137
# [batch_size, num_generations]
142
138
group_reward = reward .view (- 1 , self .num_generations )
143
139
144
140
# [batch_size x num_generations]
145
141
reward_mean = group_reward .mean (dim = 1 ).repeat_interleave (self .num_generations , dim = 0 )
146
142
reward_std = group_reward .std (dim = 1 ).repeat_interleave (self .num_generations , dim = 0 )
147
143
# [batch_size x num_generations]
148
- advantages = (group_reward .view (- 1 ) - reward_mean ) / (reward_std + 1e-4 )
149
-
150
- # GRPO advantage calculation
151
- kl = torch .sum (- 0.01 * (action_log_probs - reference_action_log_probs ) * action_mask , dim = - 1 ) / torch .sum (
152
- action_mask , dim = - 1
153
- )
144
+ advantages = (reward - reward_mean ) / (reward_std + 1e-4 )
154
145
155
146
# Calculate Loss
156
147
loss , skip_update , _ = self .policy_loss_fn (
157
148
action_log_probs ,
158
149
old_action_log_probs ,
159
150
advantages .unsqueeze (dim = - 1 ).repeat_interleave (action_log_probs .size (- 1 ), dim = - 1 ),
151
+ per_token_kl ,
160
152
action_mask ,
161
153
)
162
154
163
- loss = loss / self .num_microbatches
164
155
if not skip_update :
165
156
self .booster .backward (loss , self .optimizer )
166
- loss = all_reduce_mean (loss )
167
- reward = all_reduce_mean (reward .mean ())
168
- kl = all_reduce_mean (kl .mean ())
157
+ loss = all_reduce_mean (loss , self . plugin )
158
+ reward = all_reduce_mean (reward .mean (), self . plugin )
159
+ kl = all_reduce_mean (kl .mean (), self . plugin )
169
160
self .accum_loss .add_ (loss .data )
170
161
self .accum_reward .add_ (reward .data )
171
162
self .accum_kl .add_ (kl .data )
163
+ self .accum_count += 1
172
164
if need_update :
173
165
self .optimizer .step ()
174
166
self .optimizer .zero_grad ()
175
167
loss_scalar = self .accum_loss .item ()
176
168
if self .rank == 0 :
177
- print ("Loss:" , self .accum_loss .item (), "Reward:" , self .accum_reward .item (), "KL:" , self .accum_kl .item ())
169
+ print (
170
+ "Loss:" ,
171
+ self .accum_loss .item () / self .accum_count ,
172
+ "Reward:" ,
173
+ self .accum_reward .item () / self .accum_count ,
174
+ "KL:" ,
175
+ self .accum_kl .item () / self .accum_count ,
176
+ )
178
177
self .wandb_run .log (
179
178
{
180
- "train/loss" : self .accum_loss .item (),
181
- "train/reward" : self .accum_reward .item (),
182
- "train/kl" : self .accum_kl .item (),
179
+ "train/loss" : self .accum_loss .item () / self . accum_count ,
180
+ "train/reward" : self .accum_reward .item () / self . accum_count ,
181
+ "train/kl" : self .accum_kl .item () / self . accum_count ,
183
182
}
184
183
)
185
- # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
186
- # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
187
- # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
188
- # self.global_step += 1
189
184
self .accum_loss .zero_ ()
190
185
self .accum_reward .zero_ ()
191
186
self .accum_kl .zero_ ()
187
+ self .accum_count = 0
192
188
return loss_scalar
193
189
194
190
def state_dict (self ):
0 commit comments