
Commit 53371c6

Author: Hossein Kavianihamedani (committed)
Implement Epoch-Based Evaluation with Non-Blocking All-Reduce
1 parent a0f62e7 commit 53371c6

File tree

2 files changed (+542, -17 lines)


apps/sft/main.py

Lines changed: 105 additions & 17 deletions
@@ -279,8 +279,29 @@ def train_step(self, batch) -> None:
         self.optimizers.step()
         self.lr_schedulers.step()
 
+    def _extract_epoch_from_batch(self, batch: dict) -> int | None:
+        """Extract epoch number from batch metrics."""
+        if "metrics" not in batch:
+            return None
+
+        for metric in batch["metrics"]:
+            if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
+                return metric.value
+        return None
+
     async def evaluate(self) -> dict[str, float]:
-        """Run evaluation on validation set (internal method, not an endpoint)."""
+        """Run evaluation on validation set for one complete epoch.
+
+        Uses prefetch + non-blocking all_reduce pattern to detect epoch completion
+        across all ranks without blocking on every batch.
+
+        Pattern:
+        - Iteration N: Start async all_reduce on next batch's epoch (non-blocking)
+        - Process current batch while all_reduce completes in background
+        - Iteration N+1: Check result from previous all_reduce (should be done)
+
+        This overlaps communication with computation for better performance.
+        """
         logger.info("=" * 50)
         logger.info("STARTING EVALUATION ")
         logger.info("=" * 50)
@@ -292,30 +313,97 @@ async def evaluate(self) -> dict[str, float]:
         val_dataloader = iter(self.val_dataloader)
         total_loss = 0.0
         num_batches = 0
+        starting_epoch = None
+
+        # Prefetch first batch
+        try:
+            next_batch = next(val_dataloader)
+        except StopIteration:
+            logger.warning("Validation dataloader is empty")
+            return {"val_loss": 0.0, "val_batches": 0}
+
+        next_should_break = False
+        pending_work = None  # Handle for async all_reduce
+        epoch_tensor = None  # Tensor for all_reduce result
 
         with torch.no_grad():
-            for step in range(self.eval_steps):
-                try:
-                    batch = next(val_dataloader)
+            while True:
+                # Check result from PREVIOUS iteration's async all_reduce
+                if pending_work is not None:
+                    pending_work.wait()  # Should be complete (or very fast) since we did compute
+                    if epoch_tensor is not None:
+                        next_should_break = epoch_tensor.item() > 0
+                    pending_work = None
+
+                # Check if we should break (based on previous iteration's check)
+                if next_should_break:
+                    logger.info(
+                        "Epoch completed across all ranks - stopping evaluation"
+                    )
+                    break
 
-                    # Move tensors to device
-                    for k, v in batch.items():
-                        if isinstance(v, torch.Tensor):
-                            batch[k] = v.to(self.device)
+                # Check optional cap on eval steps
+                if self.eval_steps > 0 and num_batches >= self.eval_steps:
+                    logger.info(f"Reached eval_steps cap of {self.eval_steps}")
+                    break
 
-                    labels = batch.pop("labels")
-                    loss = self.forward_only(batch, labels)
+                # Use the batch that was prefetched in previous iteration
+                batch = next_batch
 
-                    total_loss += loss.item()
-                    num_batches += 1
+                # Extract epoch from current batch
+                current_epoch = self._extract_epoch_from_batch(batch)
+                if current_epoch is not None and starting_epoch is None:
+                    starting_epoch = current_epoch
+                    logger.info(f"Starting evaluation at epoch {starting_epoch}")
 
-                    logger.info(
-                        f" Eval batch {num_batches}/{self.eval_steps} | Loss: {loss.item():.4f}"
-                    )
+                # Prefetch next batch and start async all_reduce
+                try:
+                    next_batch = next(val_dataloader)
+
+                    # Extract epoch from next batch
+                    next_epoch = self._extract_epoch_from_batch(next_batch)
+
+                    # Start NON-BLOCKING all_reduce to check if any rank completed epoch
+                    if next_epoch is not None and starting_epoch is not None:
+                        # Check if next batch indicates epoch completion
+                        epoch_increment = next_epoch - starting_epoch
+
+                        if torch.distributed.is_initialized():
+                            # Create tensor for all_reduce
+                            epoch_tensor = torch.tensor(
+                                [epoch_increment], dtype=torch.long, device=self.device
+                            )
+                            # Start async all_reduce (returns immediately, doesn't block)
+                            pending_work = torch.distributed.all_reduce(
+                                epoch_tensor,
+                                op=torch.distributed.ReduceOp.MAX,
+                                async_op=True,  # NON-BLOCKING - returns immediately
+                            )
+                        else:
+                            # Single rank case - just check locally
+                            next_should_break = epoch_increment > 0
 
                 except StopIteration:
-                    logger.warning("Reached end of validation dataloader early")
-                    break
+                    # No more batches - this is the last one
+                    next_should_break = True
+
+                # Process current batch (while all_reduce completes in background)
+                # Move tensors to device
+                for k, v in batch.items():
+                    if isinstance(v, torch.Tensor):
+                        batch[k] = v.to(self.device)
+
+                labels = batch.pop("labels")
+                loss = self.forward_only(batch, labels)
+                # GPU compute happens here while network does all_reduce
+
+                total_loss += loss.item()
+                num_batches += 1
+
+                eval_steps_info = f"/{self.eval_steps}" if self.eval_steps > 0 else ""
+                logger.info(
+                    f" Eval batch {num_batches}{eval_steps_info} | Loss: {loss.item():.4f}"
+                )
 
         # Set model back to train mode
        for model_part in self.model_parts:
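A reading of the design choice (not text from the commit): each rank contributes next_epoch - starting_epoch, and reducing with ReduceOp.MAX makes the result positive on every rank as soon as any single rank's sampler wraps into the next epoch. For example, with four ranks reporting increments 1, 0, 0, 0 the MAX is 1 everywhere, so all ranks break on the same iteration and issue the same number of collective calls, which avoids one rank hanging in an all_reduce its peers never enter. With eval_steps <= 0 there is no step cap, so the epoch signal (or dataloader exhaustion) is the only stopping condition.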