Commit 4ec7329

use consumer global step

1 parent 094f119 commit 4ec7329

File tree: 3 files changed (+7, -5 lines)

applications/ColossalChat/coati/distributed/grpo_consumer.py
applications/ColossalChat/coati/distributed/launch.py
applications/ColossalChat/coati/distributed/producer.py

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 0 additions & 2 deletions

@@ -266,7 +266,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
         self.total_sample_count += total_samples.item()
-
         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
@@ -522,7 +521,6 @@ def _criterion(outputs, inputs):
             # All gather excessive prompts index across DP ranks.
             excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
             excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-
             return loss_scalar, excessive_prompts_idx
         else:
             return None, excessive_prompts_idx

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 0 additions & 1 deletion

@@ -56,7 +56,6 @@ def launch_distributed(
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
 ):
-
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 7 additions & 2 deletions

@@ -58,6 +58,7 @@ def __init__(
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.lastest_eval_step = -1
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -178,12 +179,15 @@ def loop(self) -> None:
                 if i >= num_valid_microbatches:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
-                    if i % self.eval_interval == 0:
+                    if (
+                        self.consumer_global_step % self.eval_interval == 0
+                        and self.consumer_global_step > self.lastest_eval_step
+                    ):
                         to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -223,6 +227,7 @@ def loop(self) -> None:
 
                         if self.producer_idx == 0:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                        self.lastest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
