@@ -58,6 +58,7 @@ def __init__(
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.latest_eval_step = -1
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -178,12 +179,15 @@ def loop(self) -> None:
                 if i >= num_valid_microbatches:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
-                    if i % self.eval_interval == 0:
+                    if (
+                        self.consumer_global_step % self.eval_interval == 0
+                        and self.consumer_global_step > self.latest_eval_step
+                    ):
                         to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -223,6 +227,7 @@ def loop(self) -> None:
 
                         if self.producer_idx == 0:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                        self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")