
Commit 3746f73

fix missing or wrong file during rebase
1 parent 118a66f

File tree: 2 files changed, +6 −20 lines

applications/ColossalChat/coati/distributed/consumer.py (1 addition, 0 deletions)

@@ -217,6 +217,7 @@ def loop(self) -> None:
         effective_group_mask = None
         if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
             # filter the group based on the reward and accuracy
+            group_ans_acc_mean = ans_acc.mean(dim=1)
             effective_group_mask = torch.logical_and(
                 group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
             )
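
Context for the fix above: the mask reads group_ans_acc_mean, which appears to have been left undefined in loop() after the rebase; the restored line computes it from ans_acc before the mask is built. Below is a minimal runnable sketch of the same filter, where the [num_groups, num_generations] shape of ans_acc and the (0.1, 0.9) bounds are illustrative assumptions, not values from the repo:

import torch

def compute_effective_group_mask(ans_acc: torch.Tensor, filter_range) -> torch.Tensor:
    # Mean answer accuracy per group -- the line this commit restores.
    group_ans_acc_mean = ans_acc.mean(dim=1)
    # Keep groups whose mean accuracy lies strictly inside (low, high);
    # all-wrong or all-right groups carry little advantage signal in GRPO.
    return torch.logical_and(
        group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
    )

# Three groups of four generations each, filtered to mean accuracy in (0.1, 0.9):
acc = torch.tensor([[1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0]])
print(compute_effective_group_mask(acc, (0.1, 0.9)))  # tensor([False,  True, False])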

applications/ColossalChat/coati/distributed/grpo_consumer.py (5 additions, 20 deletions)

@@ -91,9 +91,6 @@ def __init__(
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-        self.total_sample_count = 0
-        self.overlength_samples = 0
-        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name

@@ -207,7 +204,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
-            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
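
An aside on the filter kept here: action_mask marks generated tokens, so a sample whose final slot is still active ran into max_length and was truncated; the deleted old_loss_mask copy presumably served only the overlength bookkeeping removed below. A small sketch of the check, with hypothetical shapes and values:

import torch

def filter_truncated(loss_mask: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    # A sample whose last action slot is active never emitted an end token,
    # i.e. its response was cut off at max_length, so drop it from the loss.
    return torch.logical_and(loss_mask, action_mask[:, -1] == False)

loss_mask = torch.tensor([True, True])
action_mask = torch.tensor([[1, 1, 1, 1],   # generated through the last slot: truncated
                            [1, 1, 0, 0]])  # finished early: kept
print(filter_truncated(loss_mask, action_mask))  # tensor([False,  True])
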
@@ -225,15 +221,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     group_ans_acc_mean < self.filter_range[1],
                 ),
             )
-            self.total_overlength_samples += self.overlength_samples.item()
-
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size

         mean_kl, mean_loss = [], []

@@ -250,8 +238,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
-                    "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                    "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                    "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
                 }
             )

@@ -477,12 +464,10 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            sample_utilization = self.effective_sample_count / self.total_sample_count
-            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
+            # no need to run all reduce as raw_train_batch_* are not splited across dp rank
+            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
-            self.total_sample_count = 0
-            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
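
On the sample_utilization rewrite just above: the new comment notes that raw_train_batch_* holds the full batch on every rank (it is never sharded across data-parallel ranks), so the total sample count is simply len(self.raw_train_batch_reward) * self.num_generations and needs no all-reduce; the total_sample_count and overlength counters become dead state and are dropped. And since filtering now happens in consumer.py's loop() before batching, every group reaching step() counts as effective, hence the flat group_reward.size(0) * self.dp_size increment earlier. A minimal sketch with made-up numbers (all values illustrative):

num_generations = 8                    # completions sampled per prompt
dp_size = 4                            # data-parallel ranks
raw_train_batch_reward = [0.0] * 64    # one record per prompt; full batch on every rank

# Accumulated as in the earlier hunk: prompts per minibatch * dp_size, summed over steps.
effective_prompt_count = 48
effective_sample_count = effective_prompt_count * num_generations   # 384

# total samples = prompts in the raw batch * generations per prompt = 64 * 8 = 512
sample_utilization = effective_sample_count / len(raw_train_batch_reward) / num_generations
print(f"sample utilization: {sample_utilization:.2f}")  # 0.75
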
@@ -545,4 +530,4 @@ def state_dict(self):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
-        return state_dict
+        return state_dict

(The final hunk removes and re-adds an identical return line, which in a rendered diff usually indicates a trailing-newline change at the end of the file.)
