Skip to content

Commit 021914c

Browse files

Committed: "support logging rollouts to wandb"
1 parent: 203dfb1 · commit 021914c

File tree

1 file changed: +33 additions, -9 deletions

(summary repeated on page: 1 file changed, +33 / -9 lines changed)

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
project_name: str = None,
5050
run_name: str = None,
5151
wandb_group_name: str = None,
52+
wandb_log_rollout_interval: int = 20,
5253
):
5354
self.producer_idx = producer_idx
5455
self.num_producers = num_producers
@@ -58,7 +59,7 @@ def __init__(
5859
self.microbatch_size = microbatch_size
5960
assert batch_size % microbatch_size == 0
6061
self.num_microbatches = batch_size // microbatch_size
61-
self.lastest_eval_step = -1
62+
self.latest_eval_step = -1
6263

6364
self.train_dataset_config = train_dataset_config
6465
self.model_config = model_config
@@ -68,6 +69,10 @@ def __init__(
6869
self.eval_interval = eval_interval
6970
self.eval_save_dir = eval_save_dir
7071
self.consumer_global_step = 0
72+
self.eval_mode = False
73+
self.wandb_rollout_data = []
74+
self.wandb_log_rollout_interval = wandb_log_rollout_interval
75+
self.latest_rollout_log_step = -1
7176
if self.producer_idx == 0:
7277
self.wandb_run = wandb.init(
7378
project=project_name,
@@ -77,7 +82,7 @@ def __init__(
7782
group=wandb_group_name,
7883
)
7984

80-
if os.path.exists(self.eval_save_dir):
85+
if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
8186
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
8287

8388
# init tokenizer
@@ -180,10 +185,11 @@ def loop(self) -> None:
180185
break
181186
if self.eval_interval > 0 and self.eval_dataset_config is not None:
182187
if (
183-
self.consumer_global_step - self.lastest_eval_step >= self.eval_interval
184-
and self.consumer_global_step > self.lastest_eval_step
185-
):
188+
self.consumer_global_step - self.latest_eval_step >= self.eval_interval
189+
and self.consumer_global_step > self.latest_eval_step
190+
) or self.latest_eval_step == -1:
186191
to_log_msg = {}
192+
self.eval_mode = True
187193
for eval_task_name in self.eval_dataloaders:
188194
if self.producer_idx == 0:
189195
print(
@@ -227,7 +233,8 @@ def loop(self) -> None:
227233

228234
if self.producer_idx == 0:
229235
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
230-
self.lastest_eval_step = self.consumer_global_step
236+
self.eval_mode = False
237+
self.latest_eval_step = self.consumer_global_step
231238
outputs = self.rollout(**batch)
232239

233240
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -345,9 +352,26 @@ def __init__(
345352
@torch.no_grad()
346353
def rollout(self, input_ids, attention_mask, **kwargs):
347354
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
348-
# if self.producer_idx == 1:
349-
# print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
350-
355+
if self.producer_idx == 0 and not self.eval_mode:
356+
wandb_rollout_data = self.wandb_rollout_data + [
357+
[
358+
str(self.consumer_global_step),
359+
str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
360+
]
361+
]
362+
if (
363+
self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
364+
or self.latest_rollout_log_step == -1
365+
):
366+
self.wandb_rollout_data = wandb_rollout_data
367+
self.latest_rollout_log_step = self.consumer_global_step
368+
self.wandb_run.log(
369+
{
370+
"rollout/rollout_examples": wandb.Table(
371+
columns=["train_step", "rollout_examples"], data=wandb_rollout_data
372+
)
373+
}
374+
)
351375
return rollouts
352376

353377
def load_state_dict(self, state_dict):

0 commit comments

Comments (0)