@@ -49,6 +49,7 @@ def __init__(
49
49
project_name : str = None ,
50
50
run_name : str = None ,
51
51
wandb_group_name : str = None ,
52
+ wandb_log_rollout_interval : int = 20 ,
52
53
):
53
54
self .producer_idx = producer_idx
54
55
self .num_producers = num_producers
@@ -58,7 +59,7 @@ def __init__(
58
59
self .microbatch_size = microbatch_size
59
60
assert batch_size % microbatch_size == 0
60
61
self .num_microbatches = batch_size // microbatch_size
61
- self .lastest_eval_step = - 1
62
+ self .latest_eval_step = - 1
62
63
63
64
self .train_dataset_config = train_dataset_config
64
65
self .model_config = model_config
@@ -68,6 +69,10 @@ def __init__(
68
69
self .eval_interval = eval_interval
69
70
self .eval_save_dir = eval_save_dir
70
71
self .consumer_global_step = 0
72
+ self .eval_mode = False
73
+ self .wandb_rollout_data = []
74
+ self .wandb_log_rollout_interval = wandb_log_rollout_interval
75
+ self .latest_rollout_log_step = - 1
71
76
if self .producer_idx == 0 :
72
77
self .wandb_run = wandb .init (
73
78
project = project_name ,
@@ -77,7 +82,7 @@ def __init__(
77
82
group = wandb_group_name ,
78
83
)
79
84
80
- if os .path .exists (self .eval_save_dir ):
85
+ if os .path .exists (self .eval_save_dir ) and self . eval_interval > 0 :
81
86
raise ValueError (f"Eval save dir { self .eval_save_dir } already exists. Please delete it or change the name." )
82
87
83
88
# init tokenizer
@@ -180,10 +185,11 @@ def loop(self) -> None:
180
185
break
181
186
if self .eval_interval > 0 and self .eval_dataset_config is not None :
182
187
if (
183
- self .consumer_global_step - self .lastest_eval_step >= self .eval_interval
184
- and self .consumer_global_step > self .lastest_eval_step
185
- ):
188
+ self .consumer_global_step - self .latest_eval_step >= self .eval_interval
189
+ and self .consumer_global_step > self .latest_eval_step
190
+ ) or self . latest_eval_step == - 1 :
186
191
to_log_msg = {}
192
+ self .eval_mode = True
187
193
for eval_task_name in self .eval_dataloaders :
188
194
if self .producer_idx == 0 :
189
195
print (
@@ -227,7 +233,8 @@ def loop(self) -> None:
227
233
228
234
if self .producer_idx == 0 :
229
235
self .wandb_run .log (to_log_msg , step = self .consumer_global_step )
230
- self .lastest_eval_step = self .consumer_global_step
236
+ self .eval_mode = False
237
+ self .latest_eval_step = self .consumer_global_step
231
238
outputs = self .rollout (** batch )
232
239
233
240
print (f"[P{ self .producer_idx } ] Send data { [(k , v .shape ) for k , v in outputs .items ()]} " )
@@ -345,9 +352,26 @@ def __init__(
345
352
@torch .no_grad ()
346
353
def rollout (self , input_ids , attention_mask , ** kwargs ):
347
354
rollouts = self .model .generate (input_ids , attention_mask , ** kwargs )
348
- # if self.producer_idx == 1:
349
- # print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
350
-
355
+ if self .producer_idx == 0 and not self .eval_mode :
356
+ wandb_rollout_data = self .wandb_rollout_data + [
357
+ [
358
+ str (self .consumer_global_step ),
359
+ str (self .tokenizer .decode (rollouts ["input_ids" ][0 ][0 ], skip_special_tokens = True )),
360
+ ]
361
+ ]
362
+ if (
363
+ self .consumer_global_step - self .latest_rollout_log_step >= self .wandb_log_rollout_interval
364
+ or self .latest_rollout_log_step == - 1
365
+ ):
366
+ self .wandb_rollout_data = wandb_rollout_data
367
+ self .latest_rollout_log_step = self .consumer_global_step
368
+ self .wandb_run .log (
369
+ {
370
+ "rollout/rollout_examples" : wandb .Table (
371
+ columns = ["train_step" , "rollout_examples" ], data = wandb_rollout_data
372
+ )
373
+ }
374
+ )
351
375
return rollouts
352
376
353
377
def load_state_dict (self , state_dict ):
0 commit comments