1
1
import copy
2
+ import json
2
3
import os
3
4
from typing import Any , Dict , Optional
4
5
@@ -49,7 +50,8 @@ def __init__(
49
50
project_name : str = None ,
50
51
run_name : str = None ,
51
52
wandb_group_name : str = None ,
52
- wandb_log_rollout_interval : int = 20 ,
53
+ log_rollout_interval : int = 20 ,
54
+ rollout_log_file : str = "./rollout_log.jsonl" ,
53
55
):
54
56
self .producer_idx = producer_idx
55
57
self .num_producers = num_producers
@@ -70,9 +72,16 @@ def __init__(
70
72
self .eval_save_dir = eval_save_dir
71
73
self .consumer_global_step = 0
72
74
self .eval_mode = False
73
- self .wandb_rollout_data = []
74
- self .wandb_log_rollout_interval = wandb_log_rollout_interval
75
+ self .log_rollout_interval = log_rollout_interval
75
76
self .latest_rollout_log_step = - 1
77
+ if producer_idx == 0 :
78
+ if os .path .exists (rollout_log_file ):
79
+ raise ValueError (
80
+ f"Rollout log file { rollout_log_file } already exists. Please delete it or change the name."
81
+ )
82
+ else :
83
+ os .makedirs (os .path .dirname (rollout_log_file ), exist_ok = True )
84
+ self .rollout_log_file = open (rollout_log_file , "w" , encoding = "utf8" )
76
85
if self .producer_idx == 0 :
77
86
self .wandb_run = wandb .init (
78
87
project = project_name ,
@@ -320,6 +329,8 @@ def __init__(
320
329
project_name : str = None ,
321
330
run_name : str = None ,
322
331
wandb_group_name : str = None ,
332
+ log_rollout_interval : int = 20 ,
333
+ rollout_log_file : str = "./rollout_log.jsonl" ,
323
334
):
324
335
super ().__init__ (
325
336
producer_idx ,
@@ -342,6 +353,8 @@ def __init__(
342
353
project_name = project_name ,
343
354
run_name = run_name ,
344
355
wandb_group_name = wandb_group_name ,
356
+ log_rollout_interval = log_rollout_interval ,
357
+ rollout_log_file = rollout_log_file ,
345
358
)
346
359
self .model = self .backend_cls (model_config , generate_config , self .tokenizer , num_generations )
347
360
self .eval_generation_config = copy .deepcopy (self .model .generate_config )
@@ -353,26 +366,31 @@ def __init__(
353
366
def rollout(self, input_ids, attention_mask, **kwargs):
    """Generate rollouts and periodically append a sample to the rollout log.

    Args:
        input_ids: Prompt token ids, forwarded unchanged to ``self.model.generate``.
        attention_mask: Attention mask, forwarded unchanged to ``self.model.generate``.
        **kwargs: Extra keyword arguments forwarded to ``self.model.generate``.

    Returns:
        Whatever ``self.model.generate`` returned, unmodified.

    Side effects:
        On producer 0 outside eval mode, writes one JSON line to
        ``self.rollout_log_file`` (and flushes it) when nothing has been
        logged yet or at least ``self.log_rollout_interval`` consumer steps
        have passed since the last logged step, then records the current
        step in ``self.latest_rollout_log_step``.
    """
    rollouts = self.model.generate(input_ids, attention_mask, **kwargs)

    # Only the first producer logs, and only for training rollouts.
    if self.producer_idx != 0 or self.eval_mode:
        return rollouts

    never_logged = self.latest_rollout_log_step == -1
    steps_since_last_log = self.consumer_global_step - self.latest_rollout_log_step
    if not (steps_since_last_log >= self.log_rollout_interval or never_logged):
        return rollouts

    # NOTE(review): `[:, 0]` presumably picks the first generation of every
    # prompt in the batch -- confirm against the backend's output shape.
    decoded_rollouts = self.tokenizer.batch_decode(rollouts["input_ids"][:, 0], skip_special_tokens=True)
    record = {"train_step": self.consumer_global_step, "rollout": decoded_rollouts}

    log_file = self.rollout_log_file
    log_file.write(json.dumps(record) + "\n")
    # Flush right away so the record reaches the file without waiting for close.
    log_file.flush()
    self.latest_rollout_log_step = self.consumer_global_step
    return rollouts
376
388
389
def __del__(self):
    """Finish the wandb run and close the rollout log when the producer is destroyed.

    The finalizer also runs on instances whose ``__init__`` raised part-way
    through (for example when the rollout log file already exists), so every
    attribute is looked up defensively instead of being assumed present.
    """
    try:
        # Only producer 0 owns a wandb run; it may be missing if __init__ failed early.
        if getattr(self, "producer_idx", None) == 0:
            wandb_run = getattr(self, "wandb_run", None)
            if wandb_run is not None:
                wandb_run.finish()
    finally:
        # Close the log file even if finishing the wandb run raised.
        log_file = getattr(self, "rollout_log_file", None)
        if log_file is not None:
            log_file.close()
394
+
377
395
def load_state_dict (self , state_dict ):
378
396
self .model .load_state_dict (state_dict )
0 commit comments