@@ -321,6 +321,8 @@ def cyclic_iter(iterable):
                     yield x

         self.resample_iterator = cyclic_iter(self.get_resample_dataloader())
+        # flag indicating whether the evaluation has started
+        self.eval_flag = False

     def split_batches(self):
         """Sync weights in batches
@@ -1089,6 +1091,10 @@ def _get_per_token_logps(self, model, inputs):
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens

     def evaluation_loop(self, dataloader, *args, **kwargs):
+        # Wait for the training rollout to complete
+        if self.args.async_generate:
+            while not self.is_async_generate_train_rollout_done():
+                time.sleep(0.1)
         # set mini_batch_size None in evaluation
         mini_batch_size = self.args.mini_batch_size
         self.args.mini_batch_size = None
@@ -1099,13 +1105,17 @@ def evaluation_loop(self, dataloader, *args, **kwargs):
         metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()}
         output.metrics.update(metrics)
         self.args.mini_batch_size = mini_batch_size
+        self.eval_flag = True
         return output

     def training_step(self,
                       model: nn.Module,
                       inputs: Dict[str, Union[torch.Tensor, Any]],
                       num_items_in_batch=None) -> torch.Tensor:
-
+        if self.args.async_generate:
+            # Wait for the eval rollout to complete
+            while not self.is_async_generate_eval_rollout_done():
+                time.sleep(0.1)
         if self.args.mini_batch_size is None:
             return super().training_step(model, inputs, num_items_in_batch)
         model.train()
@@ -1326,3 +1336,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
             if self.args.wandb_log_unique_prompts:
                 df = df.drop_duplicates(subset=['prompt'])
             wandb.log({'completions': wandb.Table(dataframe=df)})
+
+    def is_async_generate_eval_rollout_done(self):
+        return not self.eval_flag or not self.eval_queue.empty()
+
+    def is_async_generate_train_rollout_done(self):
+        return not self.train_queue.empty()
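
The added methods gate training and evaluation on the state of the async rollout queues: each step busy-waits, polling every 0.1 s, until the corresponding rollout has been pushed to its queue, and `eval_flag` keeps training steps from blocking before the first evaluation has ever run. Below is a minimal, self-contained sketch of that handshake under stated assumptions: `AsyncRolloutGate` and the `producer` thread are illustrative stand-ins, while `eval_flag`, `eval_queue`, `train_queue`, and the two `is_async_generate_*_rollout_done` predicates mirror the methods added in this diff.

```python
# Minimal sketch of the rollout handshake above -- not the trainer itself.
# `AsyncRolloutGate` and `producer` are illustrative stand-ins; the queues,
# the flag, and the two predicates mirror the methods added in this diff.
import queue
import threading
import time


class AsyncRolloutGate:

    def __init__(self):
        self.eval_queue = queue.Queue()
        self.train_queue = queue.Queue()
        # flag indicating whether the evaluation has started
        self.eval_flag = False

    def is_async_generate_eval_rollout_done(self):
        # True before the first evaluation (nothing to wait for), or once
        # the async worker has pushed the eval rollout into its queue.
        return not self.eval_flag or not self.eval_queue.empty()

    def is_async_generate_train_rollout_done(self):
        return not self.train_queue.empty()


def producer(gate):
    # Stand-in for the async generation worker finishing an eval rollout.
    time.sleep(0.5)
    gate.eval_queue.put('eval rollout batch')


gate = AsyncRolloutGate()
gate.eval_flag = True  # pretend one evaluation has already started
threading.Thread(target=producer, args=(gate, ), daemon=True).start()

# The training step polls every 0.1 s, exactly as in training_step above.
while not gate.is_async_generate_eval_rollout_done():
    time.sleep(0.1)
print('eval rollout ready; training step may proceed')
```

Polling with a short sleep keeps the gating simple at the cost of up to 0.1 s of added latency per step; a blocking `queue.Queue.get()` with a timeout would avoid spinning, but the polling loop above matches what the diff does.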