
Commit b9e6055

Merge pull request #6208 from hpcaitech/grpo_dev
[Chat] fix colossalchat bugs
2 parents 9379cbd + 7595c45 commit b9e6055

File tree

8 files changed: +10, -10 lines changed


applications/ColossalChat/coati/experience_maker/naive.py

Lines changed: 1 addition & 1 deletion

@@ -140,7 +140,7 @@ def make_experience(
         num_actions = 0

         for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
-            s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size
+            s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
             if input_ids[s:e].size(0) == 0:
                 break
             sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
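
This is the substantive bug: the loop advances in strides of `self.inference_batch_size`, so each slice should end at `start + batch_size`. The old bound `(start + 1) * batch_size` grows faster than the stride, so mini-batches after the first exceed the intended size and overlapping rows are processed twice. A standalone sketch with illustrative values (not taken from the repo):

# Sketch of the mini-batch slicing fix in naive.py (illustrative values only).
total, batch_size = 10, 4
rows = list(range(total))

# Old bound: (start + 1) * batch_size grows faster than the loop stride, so
# later slices are oversized and overlap the previous ones.
old = [rows[s : (s + 1) * batch_size] for s in range(0, total, batch_size)]
# old == [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9], [8, 9]]

# New bound: start + batch_size caps every slice at batch_size rows and
# visits each row exactly once.
new = [rows[s : s + batch_size] for s in range(0, total, batch_size)]
# new == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert [r for chunk in new for r in chunk] == rows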

applications/ColossalChat/coati/trainer/dpo.py

Lines changed: 1 addition & 1 deletion

@@ -380,8 +380,8 @@ def _criterion(outputs, inputs):
                         self.accumulative_meter.get("accuracy"),
                         global_step,
                     )
-                self.num_train_step += 1
                 self.accumulative_meter.reset()
+                self.num_train_step += 1

                 if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
                     # save checkpoint
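
In the DPO trainer the visible change is ordering: `self.num_train_step += 1` now comes after `self.accumulative_meter.reset()`, grouping the counter bump with the rest of the end-of-step bookkeeping that runs before the `save_interval` checkpoint check a few lines below. A runnable schematic of that post-change order; every name here is an illustrative stand-in, not the trainer's real API:

# Schematic of the end-of-step bookkeeping order after this change.
# log_metrics / save_checkpoint are hypothetical stand-ins; save_interval is illustrative.
def log_metrics(step: int) -> None:
    pass  # placeholder for the writer.add_scalar(...) calls

def save_checkpoint(step: int) -> None:
    print(f"checkpoint at step {step}")

num_train_step = 0
save_interval = 100

for _ in range(300):                 # stand-in for the optimizer steps
    # ... forward / backward / optimizer.step() would happen here ...
    log_metrics(num_train_step)      # 1. flush the accumulated metrics
    # accumulative_meter.reset()     # 2. clear the running averages
    num_train_step += 1              # 3. count the step as completed
    if num_train_step % save_interval == 0:
        save_checkpoint(num_train_step)  # 4. checkpoint on completed-step boundaries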

applications/ColossalChat/coati/trainer/grpo.py

Lines changed: 2 additions & 2 deletions

@@ -231,7 +231,6 @@ def _training_step(self, experience: Experience):
         experience:
             sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
         """
-        self.num_train_step += 1
         self.actor.train()
         num_actions = experience.action_log_probs.size(1)
         # policy loss
@@ -294,7 +293,7 @@ def _training_step(self, experience: Experience):
             self.temperature_annealing_scheduler.step_forward()

         # preparing logging model output and corresponding rewards.
-        if self.num_train_step % 10 == 1:
+        if self.num_train_step % 10 == 0:
             response_text = self.experience_maker.tokenizer.batch_decode(
                 experience.sequences, skip_special_tokens=True
             )
@@ -327,6 +326,7 @@ def _training_step(self, experience: Experience):
             self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
             self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
         self.accumulative_meter.reset()
+        self.num_train_step += 1

     def _learn(self, update_step: int):
         """

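The two GRPO edits are coupled. The counter used to be bumped at the top of `_training_step`, so during the step it already held the 1-based step index and `num_train_step % 10 == 1` logged model output on steps 1, 11, 21, and so on. With the bump moved to the very end, the counter holds the 0-based index during the step, and `% 10 == 0` selects the same steps. A small standalone check of that equivalence (not trainer code):

# The logging cadence is unchanged: pre-increment + "% 10 == 1" and
# post-increment + "% 10 == 0" select the same training steps.
def logged_steps_old(n_steps: int) -> list[int]:
    logged, counter = [], 0
    for step in range(n_steps):
        counter += 1              # old: increment at the top of the step
        if counter % 10 == 1:
            logged.append(step)
    return logged

def logged_steps_new(n_steps: int) -> list[int]:
    logged, counter = [], 0
    for step in range(n_steps):
        if counter % 10 == 0:     # new: counter still holds the 0-based index here
            logged.append(step)
        counter += 1              # new: increment at the end of the step
    return logged

assert logged_steps_old(50) == logged_steps_new(50) == [0, 10, 20, 30, 40]
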
applications/ColossalChat/coati/trainer/kto.py

Lines changed: 1 addition & 1 deletion

@@ -256,7 +256,7 @@ def _train(self, epoch: int):
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
                    )
-                   self.num_train_step += 1
+                   self.num_train_step += 1

        step_bar.close()

applications/ColossalChat/coati/trainer/orpo.py

Lines changed: 1 addition & 1 deletion

@@ -233,7 +233,7 @@ def _train(self, epoch: int):
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
                    )
-                   self.num_train_step += 1
+                   self.num_train_step += 1

        step_bar.close()

applications/ColossalChat/coati/trainer/ppo.py

Lines changed: 2 additions & 2 deletions

@@ -220,7 +220,6 @@ def _training_step(self, experience: Experience):
         experience:
             sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
         """
-        self.num_train_step += 1
         self.actor.train()
         self.critic.train()
         num_actions = experience.action_log_probs.size(1)
@@ -294,7 +293,7 @@ def _training_step(self, experience: Experience):
             self.critic_scheduler.step()

         # preparing logging model output and corresponding rewards.
-        if self.num_train_step % 10 == 1:
+        if self.num_train_step % 10 == 0:
             response_text = self.experience_maker.tokenizer.batch_decode(
                 experience.sequences, skip_special_tokens=True
             )
@@ -336,6 +335,7 @@ def _training_step(self, experience: Experience):
             self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
             self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
         self.accumulative_meter.reset()
+        self.num_train_step += 1

     def _learn(self, update_step: int):
         """

applications/ColossalChat/coati/trainer/rm.py

Lines changed: 1 addition & 1 deletion

@@ -193,7 +193,7 @@ def _train(self, epoch):
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
                    )
-                   self.num_train_step += 1
+                   self.num_train_step += 1
        step_bar.close()

    def _eval(self, epoch):

applications/ColossalChat/coati/trainer/sft.py

Lines changed: 1 addition & 1 deletion

@@ -152,9 +152,9 @@ def _train(self, epoch: int):
                if self.writer:
                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
-                   self.num_train_step += 1
                self.accumulative_meter.reset()
                step_bar.update()
+               self.num_train_step += 1

            # Save checkpoint
            if (
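
The SFT change follows the same pattern, and here the placement matters most: judging by its position right after the two `add_scalar` calls, the removed increment appears to have sat inside the `if self.writer:` branch, which only executes on processes that own a TensorBoard writer. Moving it after `step_bar.update()` lets the counter advance on every optimizer step whether or not logging is active, so the save-checkpoint condition below behaves the same on all ranks. A toy illustration of the failure mode (all names hypothetical):

# Why tying the step counter to the logging guard is a bug: on a process
# without a writer the counter never advances, so interval-based
# checkpointing never fires. All names here are hypothetical.
def completed_steps(n_steps: int, has_writer: bool, increment_inside_guard: bool) -> int:
    counter = 0
    for _ in range(n_steps):
        if has_writer:
            # writer.add_scalar(...) would go here
            if increment_inside_guard:
                counter += 1      # old placement: only counted when logging runs
        if not increment_inside_guard:
            counter += 1          # new placement: counted on every step
    return counter

assert completed_steps(100, has_writer=False, increment_inside_guard=True) == 0
assert completed_steps(100, has_writer=False, increment_inside_guard=False) == 100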
