
Commit 2aa7385

update logging

1 parent d8eaf0d

File tree

2 files changed: +9 additions, -5 deletions


applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 6 additions & 5 deletions
@@ -133,7 +133,6 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
-
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
@@ -243,13 +242,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 )
                 self.wandb_run.log(
                     {
+                        "metrics/reward": self.accum_reward.item() / self.accum_count,
+                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/response_length": self.accum_response_length.item() / self.accum_count,
+                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                 )
                 self.accum_loss.zero_()
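This hunk regroups the consumer's W&B keys: the accumulated reward statistics move from the "train/" prefix to a "metrics/" prefix, and two new values are logged, the current learning rate and the rollout sampling temperature. W&B treats the text before the slash as a section name, so each prefix becomes its own panel group in the dashboard. A minimal, self-contained sketch of that grouping behavior (the project name and values are hypothetical, not from this repo):

import wandb

run = wandb.init(project="grpo-logging-demo", mode="offline")  # hypothetical project; offline mode needs no account
for step in range(3):
    run.log(
        {
            "metrics/reward": 0.5 + 0.1 * step,  # "metrics/" keys group into one panel section
            "train/loss": 1.0 / (step + 1),      # "train/" keys group into another
            "rollout/temperature": 0.8,          # "rollout/" keys group into a third
        }
    )
run.finish()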

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 3 additions & 0 deletions
@@ -101,6 +101,9 @@ def loop(self) -> None:
                     break
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+                outputs["temperature"] = torch.tensor(
+                    [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+                ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
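The producer attaches the sampling temperature to the output dict before pre_send, since ray_broadcast_tensor_dict moves tensors rather than Python scalars; replicating the value once per sample keeps it shaped like the rest of the batch. A minimal sketch of that wrap-and-recover pattern in plain torch (the batch shape and temperature value are stand-ins, not from this repo):

import torch

input_ids = torch.randint(0, 100, (4, 16))  # stand-in for a rollout batch of 4 sequences
temperature = 0.8                           # stand-in for self.model.generate_config.temperature

outputs = {"input_ids": input_ids}
outputs["temperature"] = torch.tensor(
    [temperature] * outputs["input_ids"].size(0)  # one copy per sample, matching the batch dim
).to(outputs["input_ids"].device)

# Receiver side: recover the scalar from the first element of the tensor.
print(outputs["temperature"].cpu().numpy()[0])  # 0.8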
