Skip to content

Commit a528921

Browse files
committed
move prompt-level-filtering to buffer side
1 parent 55eee12 commit a528921

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -254,7 +254,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
254254
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
255255
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
256256
self.effective_sample_count += effective_samples.item()
257-
self.total_sample_count += total_samples.item()
258257
pbar.set_postfix(
259258
{
260259
"Global Step": self.global_step,
@@ -461,6 +460,9 @@ def _criterion(outputs, inputs):
461460
self.optimizer.step()
462461
self.optimizer.zero_grad()
463462
self.global_step += 1
463+
self.total_sample_count = all_reduce_sum(
464+
torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
465+
).item()
464466
sample_utilization = self.effective_sample_count / self.total_sample_count
465467
self.effective_prompt_count = 0
466468
self.effective_sample_count = 0
@@ -564,6 +566,7 @@ def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any
564566
"format_acc": torch.Tensor, [num_of_generation]
565567
"ans_acc": torch.Tensor, [num_of_generation]
566568
"""
569+
self.total_sample_count += rollout_group["input_ids"].size(0)
567570
if self.filter_range is not None:
568571
# filter prompt whose accuracy is too high or too low (out of range)
569572
group_ans_acc = torch.mean(rollout_group["ans_acc"])

0 commit comments

Comments
 (0)