@HosseinKaviani-H HosseinKaviani-H commented Oct 14, 2025

Add periodic evaluation during training with epoch-aware synchronization

Added evaluation functionality to the SFT training recipe with proper multi-rank synchronization and epoch completion detection.

Changes:

Core Evaluation Features

  • Configurable evaluation interval: Added eval_interval and eval_steps parameters to control when and how much to evaluate

    • eval_interval: Number of training steps between evaluations (defaults to float('inf') to disable eval when not configured)
    • eval_steps: Number of validation batches to evaluate per evaluation run (defaults to 0 for unlimited - runs one full epoch)
  • Validation dataloader: Set up separate validation dataloader using the last 10% of the train split

  • Forward-only pass: Implemented forward_only() method for evaluation without gradient computation, supporting both pipeline parallel and non-PP configurations
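As a rough illustration of the forward-only idea for the non-pipeline-parallel case, here is a minimal sketch; the actual recipe method's signature, batch layout, and loss function are assumptions, not the PR's exact code:

```python
import torch

def forward_only(model, batch, loss_fn):
    """Run one evaluation step without building an autograd graph (non-PP case)."""
    model.eval()
    with torch.no_grad():  # no gradients: saves memory and compute during eval
        logits = model(batch["tokens"])
        loss = loss_fn(logits, batch["labels"])
    return loss.item()
```

The `torch.no_grad()` context is what distinguishes this from a training step: activations needed for backward are never stored.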

Epoch-Aware Evaluation with Multi-Rank Synchronization

  • Epoch completion detection: Evaluates for exactly one complete epoch by monitoring batch["metrics"] for epoch increments

    • Extracts num_epochs from batch metadata to detect when validation dataset completes one full pass
    • Prevents evaluation from running forever on infinite streaming dataloaders
  • Non-blocking all_reduce pattern: Synchronizes epoch completion across all ranks without blocking computation

    • Prefetch + async pattern: Fetches next batch while processing current batch
    • Overlapped communication: Starts async_op=True all_reduce on next batch's epoch while GPU computes current batch's loss
    • Deferred checking: Checks all_reduce result in next iteration (should be complete by then)
    • Early stopping: All ranks stop when ANY rank completes an epoch (via MAX reduction)

Integration

  • Training integration: Evaluation runs automatically every eval_interval steps during training
  • Graceful degradation: If eval_steps > 0, it acts as a cap (useful for quick validation checks or when epoch metadata is unavailable)
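The gating logic described above can be sketched as a tiny helper; `should_evaluate` is a hypothetical name for illustration, not necessarily the recipe's actual function:

```python
def should_evaluate(step: int, eval_interval: float) -> bool:
    """Decide whether to run evaluation at this training step.

    eval_interval defaults to float('inf') when unset; since
    step % float('inf') == step, which is nonzero for any positive step,
    an unset interval disables evaluation automatically.
    """
    return step > 0 and step % eval_interval == 0
```

This is why no separate "eval enabled" flag is needed: the infinite default interval never triggers.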

Usage:

Configure in your YAML config file:

training:
  eval_interval: 100  # Evaluate every 100 training steps
  eval_steps: 0       # Run one complete epoch (recommended)
  # eval_steps: 50    # Alternative: cap at 50 batches for faster checks

If eval_interval and eval_steps are not set, evaluation is automatically disabled.

Testing:
Comprehensive test suite (test_evaluate.py) validates:

✅ Epoch extraction from batch metadata
✅ Single epoch completion detection
✅ eval_steps cap enforcement
✅ Empty/single batch edge cases
✅ Async all_reduce pattern behavior
✅ Multi-rank synchronization logic
✅ Prefetch pattern correctness
All 14 tests pass successfully.
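The epoch-extraction path the tests exercise might look roughly like the following; the helper name and the exact shape of the metadata are assumptions for illustration:

```python
def extract_num_epochs(batch: dict) -> int:
    """Read the epoch counter from batch metadata, defaulting to 0 when
    the dataloader provides no usable metrics (graceful degradation path)."""
    metrics = batch.get("metrics")
    if isinstance(metrics, dict):
        return int(metrics.get("num_epochs", 0))
    return 0
```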

Algorithm Details:
The non-blocking evaluation loop follows this pattern:

Iteration N:

  1. Check if previous all_reduce says we should stop
  2. Process current batch (forward pass, compute loss)
  3. Prefetch next batch
  4. Extract epoch from next batch
  5. Start async all_reduce on epoch_increment (returns immediately, doesn't block)

Iteration N+1:

  1. Wait for all_reduce from iteration N (should be done, or very fast)
  2. Check result: if any rank has epoch_increment > 0, stop
  3. Process batch N+1...

This overlaps network communication with GPU computation for better performance, while ensuring all ranks stop at the same point.
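The two-iteration pattern above can be simulated in pure Python; `max()` over the per-rank values stands in for the async MAX all_reduce, and the function name and inputs are illustrative assumptions, not the PR's actual code:

```python
def run_eval_loop(per_step_rank_increments, eval_steps=0):
    """Simulate the deferred-check loop. Each list entry holds the epoch
    increments every rank observed on its prefetched batch; the reduction
    result is only read at the TOP of the next iteration, mimicking how the
    real loop overlaps the async all_reduce with the GPU forward pass."""
    pending = None   # result of the all_reduce started last iteration
    steps_done = 0
    for rank_increments in per_step_rank_increments:
        # 1. check last iteration's all_reduce result
        if pending is not None and pending > 0:
            break    # some rank finished its epoch: everyone stops together
        # 2. process the current batch (forward pass would go here)
        steps_done += 1
        if eval_steps > 0 and steps_done >= eval_steps:
            break    # eval_steps acts as a hard cap when set
        # 3-5. prefetch, extract the epoch increment, start the async
        #      MAX reduction; its result is consumed next iteration
        pending = max(rank_increments)
    return steps_done
```

Because the stop signal is checked one step late, every rank processes the same number of batches before stopping, which is the synchronization property the pattern is after.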


@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 14, 2025
@felipemello1
Contributor

hey @HosseinKaviani-H, thanks for opening the PR. It's a bit tricky to run the validation, because the dataset is infinite, so it doesn't know when to stop. You can retrieve the epoch number for each dataset from batch["metrics"], but we haven't looked into that. On top of that, if you have multiple datasets, they will epoch at different paces. I think there are a few ways of handling this:

  1. changing the dataset definition for validation / adding some flag / some logic that goes in the dataset
  2. trying to orchestrate it in the recipe, i.e. checking if the dataset has epoched and stopping then.

It seems that you defined "eval_steps" as a clever solution to avoid dealing with any of that. But I wonder about correctness here, i.e. not going through the entire eval, or going through it 1.5 times, for example. Any thoughts?

@HosseinKaviani-H
Author

HosseinKaviani-H commented Oct 14, 2025

> hey @HosseinKaviani-H, thanks for opening the PR. It's a bit tricky to run the validation, because the dataset is infinite, so it doesn't know when to stop. You can retrieve the epoch number for each dataset from batch["metrics"], but we haven't looked into that. On top of that, if you have multiple datasets, they will epoch at different paces. I think there are a few ways of handling this:
>
>   1. changing the dataset definition for validation / adding some flag / some logic that goes in the dataset
>   2. trying to orchestrate it in the recipe, i.e. checking if the dataset has epoched and stopping then.
>
> It seems that you defined "eval_steps" as a clever solution to avoid dealing with any of that. But I wonder about correctness here, i.e. not going through the entire eval, or going through it 1.5 times, for example. Any thoughts?

Hi @felipemello1 ,

Thanks for your comment. Yeah, I think one good solution, as you mentioned, is to retrieve the epoch number in the training loop and break once it increments from 0 to 1. I'll give it some thought and implement it.

And yes, counting batches is arbitrary here: if eval_steps is too low it could lead to incomplete evaluation, and if it's too high it might cause duplicate evaluation. Hence, checking the epoch number sounds like a better solution.

@ebsmothers
Contributor

Leaving this comment here before a full review since I think it's relevant to the point raised by @felipemello1: previously @DNXie observed challenges with iterable dataset hanging when there are uneven numbers of samples across the ranks. In general this is a pretty hard problem to solve cleanly. So actually I would recommend going with the approach of using a fixed number of steps. You can see the full context in this torchtitan issue: pytorch/torchtitan#1618

@felipemello1
Contributor

> dataset hanging when there are uneven numbers of samples across the ranks.

@ebsmothers this should never happen to us, since we have infinite datasets. That's one of the main arguments for the infinite iterable: you don't have to worry about hanging issues. It just restarts the iterator and keeps providing new samples.

@ebsmothers
Contributor

@felipemello1 sorry maybe I don't fully understand your suggestions then. What is the termination condition for the validation loop? If it is epoch-based in any way I think we will run into this issue, right?

@felipemello1
Contributor

felipemello1 commented Oct 14, 2025

> @felipemello1 sorry maybe I don't fully understand your suggestions then. What is the termination condition for the validation loop? If it is epoch-based in any way I think we will run into this issue, right?

we can identify the change in epoch and drop the last batch using all_gather + barrier. Dummy example for a single dataset:

prev_epoch = 0
for batch in dataloader:
    epoch = batch["metrics"]["epoch"]
    is_new_epoch = torch.tensor([epoch - prev_epoch])

    # if any rank has a new epoch, we stop, i.e. drop the last batch
    gathered = [torch.zeros_like(is_new_epoch) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, is_new_epoch)
    torch.distributed.barrier()
    if max(t.item() for t in gathered) > 0:
        break
    prev_epoch = epoch

In the example above, for bsz=4, rank_0 might have 2 samples from epoch 0 and 2 from epoch 1, but the batch size would always be 4. It would never hang.


Maybe this could be done elegantly inside the dataset, hiding the logic from the recipe? But I don't think there is a pretty way.


Also not sure how to handle the multidataset situation. Perhaps:

# iterate over one dataset at a time
for dataloader in dataloaders:
      for batch in dataloader:
            ...

Does it make sense, @ebsmothers?

@ebsmothers
Contributor

@felipemello1 that's an interesting idea. In addition to your point about it not being super pretty, I am also wary of the torch.distributed.barrier() usage. I understand why it's necessary here but blocking for all ranks on every single batch is not ideal imo

@felipemello1
Contributor

felipemello1 commented Oct 14, 2025

> @felipemello1 that's an interesting idea. In addition to your point about it not being super pretty, I am also wary of the torch.distributed.barrier() usage. I understand why it's necessary here but blocking for all ranks on every single batch is not ideal imo

We could add to the ugliness and prefetch + check the epoch change on a different stream one batch in advance, so it would be non-blocking. This could be a utility kept out of the recipe. It would also only happen for validation (training is safe).

@HosseinKaviani-H
Author


@ebsmothers @felipemello1 Given our discussion and per Felipe's idea, I have implemented an epoch-based eval with non-blocking all-reduce. I have updated the description and added a test_evaluate script to cover different scenarios.

@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (main@399b20d).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #404   +/-   ##
=======================================
  Coverage        ?   73.43%           
=======================================
  Files           ?       81           
  Lines           ?     7829           
  Branches        ?        0           
=======================================
  Hits            ?     5749           
  Misses          ?     2080           
  Partials        ?        0           


@felipemello1
Contributor

felipemello1 commented Oct 17, 2025

hey Hossein, thanks! I think the tests are just mocking distributed rather than actually testing it. @ebsmothers, do we have a decorator for distributed tests in forge? Regarding the implementation, I don't think we need >100 lines to do the sampling + epoch checking. We can probably shrink it a bit.


dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
Contributor

Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate)

Author

I agree. We can have this implemented as soon as we fix the main eval functionality

@HosseinKaviani-H
Copy link
Author

> hey Hossein, thanks! I think the tests are just mocking distributed rather than actually testing it. @ebsmothers, do we have a decorator for distributed tests in forge? Regarding the implementation, I don't think we need >100 lines to do the sampling + epoch checking. We can probably shrink it a bit.

@felipemello1 I have shortened the code a bit. Let me know about the decorator for distributed tests so I can implement that as well.
