@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional
 
@@ -49,6 +50,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -58,7 +61,7 @@ def __init__(
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
-        self.lastest_eval_step = -1
+        self.latest_eval_step = -1
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -68,6 +71,17 @@ def __init__(
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        self.eval_mode = False
+        self.log_rollout_interval = log_rollout_interval
+        self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -77,7 +91,7 @@ def __init__(
                 group=wandb_group_name,
             )
 
-        if os.path.exists(self.eval_save_dir):
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
 
         # init tokenizer
@@ -180,14 +194,15 @@ def loop(self) -> None:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if (
-                        self.consumer_global_step % self.eval_interval == 0
-                        and self.consumer_global_step > self.lastest_eval_step
-                    ):
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
                         to_log_msg = {}
+                        self.eval_mode = True
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -220,14 +235,15 @@ def loop(self) -> None:
                                safe_append_to_jsonl_file(
                                    os.path.join(
                                        self.eval_save_dir,
-                                        f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                        f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
                                    ),
                                    eval_results,
                                )
 
                        if self.producer_idx == 0:
                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
-                        self.lastest_eval_step = self.consumer_global_step
+                        self.eval_mode = False
+                        self.latest_eval_step = self.consumer_global_step
                outputs = self.rollout(**batch)
 
                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -256,6 +272,8 @@ def loop(self) -> None:
                            state_dict = ray_broadcast_tensor_dict(
                                None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                            )
+                            if "consumer_global_step" in state_dict:
+                                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                            self.load_state_dict(state_dict)
                    else:
                        print(
@@ -311,6 +329,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -333,6 +353,8 @@ def __init__(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -343,10 +365,32 @@ def __init__(
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        # if self.producer_idx == 1:
-        #     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
-
+        if self.producer_idx == 0 and not self.eval_mode:
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
+                    )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts
 
+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
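For reference, each record that producer 0 appends to the rollout log is one JSON object per line, holding the consumer training step and the decoded first generation for each prompt in the micro-batch. The following is a minimal sketch of replaying that file; the path and field names ("train_step", "rollout") come from the diff above, while the reader itself is illustrative and not part of the change.

import json

# Hypothetical reader for the JSONL rollout log written by producer 0.
# Each line looks like {"train_step": <int>, "rollout": [<decoded generation per prompt>]}.
with open("./rollout_log.jsonl", encoding="utf8") as f:
    for line in f:
        record = json.loads(line)
        # Preview the first 200 characters of the first prompt's logged generation.
        print(record["train_step"], record["rollout"][0][:200])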