@@ -279,8 +279,29 @@ def train_step(self, batch) -> None:
         self.optimizers.step()
         self.lr_schedulers.step()

+    def _extract_epoch_from_batch(self, batch: dict) -> int | None:
+        """Extract epoch number from batch metrics."""
+        if "metrics" not in batch:
+            return None
+
+        for metric in batch["metrics"]:
+            if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
+                return metric.value
+        return None
+
     async def evaluate(self) -> dict[str, float]:
-        """Run evaluation on validation set (internal method, not an endpoint)."""
+        """Run evaluation on the validation set for one complete epoch.
+
+        Uses a prefetch + non-blocking all_reduce pattern to detect epoch completion
+        across all ranks without blocking on every batch.
+
+        Pattern:
+        - Iteration N: Start an async all_reduce on the next batch's epoch (non-blocking)
+        - Process the current batch while the all_reduce completes in the background
+        - Iteration N+1: Check the result of the previous all_reduce (should be done)
+
+        This overlaps communication with computation for better performance.
+        """
         logger.info("=" * 50)
         logger.info("STARTING EVALUATION")
         logger.info("=" * 50)
@@ -292,30 +313,97 @@ async def evaluate(self) -> dict[str, float]:
         val_dataloader = iter(self.val_dataloader)
         total_loss = 0.0
         num_batches = 0
+        starting_epoch = None
+
+        # Prefetch first batch
+        try:
+            next_batch = next(val_dataloader)
+        except StopIteration:
+            logger.warning("Validation dataloader is empty")
+            return {"val_loss": 0.0, "val_batches": 0}
+
+        next_should_break = False
+        pending_work = None  # Handle for async all_reduce
+        epoch_tensor = None  # Tensor for all_reduce result

         with torch.no_grad():
-            for step in range(self.eval_steps):
-                try:
-                    batch = next(val_dataloader)
+            while True:
+                # Check result from PREVIOUS iteration's async all_reduce
+                if pending_work is not None:
+                    pending_work.wait()  # Should be complete (or very fast) since compute ran in the meantime
+                    if epoch_tensor is not None:
+                        next_should_break = epoch_tensor.item() > 0
+                    pending_work = None
+
+                # Check if we should break (based on previous iteration's check)
+                if next_should_break:
+                    logger.info(
+                        "Epoch completed across all ranks - stopping evaluation"
+                    )
+                    break

-                    # Move tensors to device
-                    for k, v in batch.items():
-                        if isinstance(v, torch.Tensor):
-                            batch[k] = v.to(self.device)
+                # Check optional cap on eval steps
+                if self.eval_steps > 0 and num_batches >= self.eval_steps:
+                    logger.info(f"Reached eval_steps cap of {self.eval_steps}")
+                    break

-                    labels = batch.pop("labels")
-                    loss = self.forward_only(batch, labels)
+                # Use the batch that was prefetched in the previous iteration
+                batch = next_batch

-                    total_loss += loss.item()
-                    num_batches += 1
+                # Extract epoch from current batch
+                current_epoch = self._extract_epoch_from_batch(batch)
+                if current_epoch is not None and starting_epoch is None:
+                    starting_epoch = current_epoch
+                    logger.info(f"Starting evaluation at epoch {starting_epoch}")

-                    logger.info(
-                        f"  Eval batch {num_batches}/{self.eval_steps} | Loss: {loss.item():.4f}"
-                    )
+                # Prefetch next batch and start async all_reduce
+                try:
+                    next_batch = next(val_dataloader)
+
+                    # Extract epoch from next batch
+                    next_epoch = self._extract_epoch_from_batch(next_batch)
+
+                    # Start NON-BLOCKING all_reduce to check if any rank completed the epoch
+                    if next_epoch is not None and starting_epoch is not None:
+                        # Check if next batch indicates epoch completion
+                        epoch_increment = next_epoch - starting_epoch
+
+                        if torch.distributed.is_initialized():
+                            # Create tensor for all_reduce
+                            epoch_tensor = torch.tensor(
+                                [epoch_increment], dtype=torch.long, device=self.device
+                            )
+                            # Start async all_reduce (returns immediately, doesn't block)
+                            pending_work = torch.distributed.all_reduce(
+                                epoch_tensor,
+                                op=torch.distributed.ReduceOp.MAX,
+                                async_op=True,  # NON-BLOCKING - returns immediately
+                            )
+                        else:
+                            # Single rank case - just check locally
+                            next_should_break = epoch_increment > 0

                 except StopIteration:
-                    logger.warning("Reached end of validation dataloader early")
-                    break
+                    # No more batches - this is the last one
+                    next_should_break = True
+
+                # Process current batch (while all_reduce completes in the background)
+                # Move tensors to device
+                for k, v in batch.items():
+                    if isinstance(v, torch.Tensor):
+                        batch[k] = v.to(self.device)
+
+                labels = batch.pop("labels")
+                loss = self.forward_only(batch, labels)
+                # GPU compute happens here while the network runs the all_reduce
+
+                total_loss += loss.item()
+                num_batches += 1
+
+                eval_steps_info = f"/{self.eval_steps}" if self.eval_steps > 0 else ""
+                logger.info(
+                    f"  Eval batch {num_batches}{eval_steps_info} | Loss: {loss.item():.4f}"
+                )

         # Set model back to train mode
         for model_part in self.model_parts:
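
For reference, here is a minimal standalone sketch of the prefetch + non-blocking all_reduce pattern the new loop relies on. It is not part of the PR: `run_eval_loop`, its `batches`/`compute_loss` arguments, and the dict-shaped batches carrying an `"epoch"` key are hypothetical stand-ins for the trainer's dataloader, metric objects, and `forward_only`. It assumes a `torch.distributed` process group is already initialized when more than one rank is involved, and falls back to a purely local check otherwise.

```python
import torch
import torch.distributed as dist


def run_eval_loop(batches, device, compute_loss):
    """Overlap an async all_reduce on the NEXT batch's epoch counter with
    compute on the CURRENT batch, stopping once any rank starts a new epoch."""
    it = iter(batches)
    try:
        next_batch = next(it)
    except StopIteration:
        return 0.0, 0  # empty dataloader

    starting_epoch = None
    next_should_break = False
    pending_work = None   # async all_reduce handle from the previous iteration
    epoch_tensor = None   # tensor holding the reduced epoch increment
    total_loss, num_batches = 0.0, 0

    while True:
        # Harvest the result of the all_reduce started last iteration.
        if pending_work is not None:
            pending_work.wait()  # usually already complete by now
            next_should_break = epoch_tensor.item() > 0
            pending_work = None
        if next_should_break:
            break

        batch = next_batch
        if starting_epoch is None:
            starting_epoch = batch["epoch"]

        # Prefetch the next batch and kick off the non-blocking collective.
        try:
            next_batch = next(it)
            increment = next_batch["epoch"] - starting_epoch
            if dist.is_initialized():
                epoch_tensor = torch.tensor([increment], dtype=torch.long, device=device)
                pending_work = dist.all_reduce(
                    epoch_tensor, op=dist.ReduceOp.MAX, async_op=True
                )
            else:
                next_should_break = increment > 0  # single-process fallback
        except StopIteration:
            next_should_break = True  # no more batches; current one is the last

        # Compute on the current batch overlaps with the in-flight all_reduce.
        total_loss += compute_loss(batch)
        num_batches += 1

    return total_loss, num_batches


if __name__ == "__main__":
    # Single-process demo: three batches from epoch 0, then one from epoch 1.
    fake_batches = [{"epoch": 0}, {"epoch": 0}, {"epoch": 0}, {"epoch": 1}]
    loss, n = run_eval_loop(fake_batches, torch.device("cpu"), lambda b: 1.0)
    print(f"processed {n} batches, total loss {loss}")  # stops before the epoch-1 batch
```

The key point is that `dist.all_reduce(..., async_op=True)` returns a work handle immediately, so the loss computation for the current batch runs while the collective is in flight, and the `wait()` at the top of the next iteration is effectively free.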