Adding eval to the SFT #404
Conversation
hey @HosseinKaviani-H, thanks for opening the PR. It's a bit tricky to run the validation, because the dataset is infinite, so it doesn't know when to stop. You can retrieve the epoch number for each dataset from batch["metrics"], but we haven't looked into that. On top of that, if you have multiple datasets, they will epoch at different paces. I think that there are a few ways of handling this:
It seems that you defined "eval_steps" as a clever solution to avoid dealing with any of that. But I wonder about correctness here, i.e. not going through the entire eval set, or going through it 1.5x, for example. Any thoughts?
Hi @felipemello1, thanks for your comment. Yeah, I think one good solution, as you mentioned, is to retrieve the epoch number in the training loop and break once it increments from 0 to 1. I'll give it some thought and implement it. And yes, counting batches is arbitrary here: if eval_steps is too low it could lead to incomplete evaluation, and if it's too high it might cause duplicate evaluation. Hence, checking the epoch number sounds like the better solution.
Leaving this comment here before a full review since I think it's relevant to the point raised by @felipemello1: previously @DNXie observed challenges with an iterable dataset hanging when there are uneven numbers of samples across ranks. In general this is a pretty hard problem to solve cleanly, so actually I would recommend going with the approach of using a fixed number of steps. You can see the full context in this torchtitan issue: pytorch/torchtitan#1618
@ebsmothers this should never happen to us, since we have infinite datasets. That's one of the main arguments for the infinite iterable: you don't have to worry about hanging issues. It just restarts the iterator and keeps providing new samples.
@felipemello1 sorry, maybe I don't fully understand your suggestion then. What is the termination condition for the validation loop? If it is epoch-based in any way, I think we will run into this issue, right?
We can identify the change in epoch and handle drop-last using all_gather + barrier. Dummy example for a single dataset:
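Roughly, something like this (an illustrative sketch rather than the original snippet; the `num_epochs` key in `batch["metrics"]` and the helper name are placeholders):

```python
import torch
import torch.distributed as dist

def epoch_changed_on_any_rank(batch, current_epoch: int) -> bool:
    """Detect an epoch rollover across ranks for a single dataset (sketch)."""
    local_epoch = torch.tensor([batch["metrics"]["num_epochs"]], device="cuda")
    gathered = [torch.zeros_like(local_epoch) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_epoch)  # every rank sees every rank's epoch
    dist.barrier()                          # keep ranks in lockstep before deciding
    # If any rank has already rolled over, everyone stops, so no rank hangs.
    return any(e.item() > current_epoch for e in gathered)
```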
In the example above, for bsz=4, maybe rank_0 would have 2 samples from epoch 0 and 2 from epoch 1. But the batch size would always be 4; it would never hang. Maybe this could be done elegantly inside of the dataset to hide the logic from the recipe? But I don't think there is a pretty way. Also not sure how to handle the multi-dataset situation. Perhaps:
Does it make sense, @ebsmothers?
@felipemello1 that's an interesting idea. In addition to your point about it not being super pretty, I am also wary of the blocking collective on every batch.
We could add to the ugliness and prefetch + check the epoch change on a different stream one epoch in advance, so it would be non-blocking. This can be a utility and kept out of the recipe. It would also only happen for validation (training is safe).
@ebsmothers @felipemello1 Given our discussion and per Felipe's idea, I have implemented epoch-based eval with a non-blocking all-reduce. I have updated the description and added a test_evaluate script to cover different scenarios.
Codecov Report

✅ All modified and coverable lines are covered by tests.

| Coverage diff | main | #404 |
|---|---|---|
| Coverage | ? | 73.43% |
| Files | ? | 81 |
| Lines | ? | 7829 |
| Branches | ? | 0 |
| Hits | ? | 5749 |
| Misses | ? | 2080 |
| Partials | ? | 0 |
hey Hossein, thanks! I think that the tests are just mocking distributed and not actually testing it. @ebsmothers, do we have a decorator for distributed tests in forge? Regarding the implementation, I don't think we need >100 lines to do the sampling + epoch checking. We can probably shrink it a bit.
```python
dataset = sft_iterable_dataset(
    model_transform=tokenizer,
    message_transform=AlpacaToMessages(),
```
Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate)
I agree. We can have this implemented as soon as we fix the main eval functionality
@felipemello1 I have shortened the code a bit. Let me know about the distributed testing decorator so I can have that implemented as well.
Add periodic evaluation during training with epoch-aware synchronization
Added evaluation functionality to the SFT training recipe with proper multi-rank synchronization and epoch completion detection.
Changes:
Core Evaluation Features
- Configurable evaluation interval: Added `eval_interval` and `eval_steps` parameters to control when and how much to evaluate
  - `eval_interval`: Number of training steps between evaluations (defaults to `float('inf')` to disable eval when not configured)
  - `eval_steps`: Number of validation batches to evaluate per evaluation run (defaults to `0` for unlimited - runs one full epoch)
- Validation dataloader: Set up a separate validation dataloader using the last 10% of the train split
- Forward-only pass: Implemented a `forward_only()` method for evaluation without gradient computation, supporting both pipeline-parallel and non-PP configurations (see the sketch below)
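For the non-PP path, a minimal sketch of what this step can look like (illustrative only; the actual method also handles pipeline-parallel schedules, and the batch keys below are assumptions about the recipe's batch layout):

```python
import torch

def forward_only(model, loss_fn, batch) -> torch.Tensor:
    """Evaluation step without gradient tracking (non-PP sketch)."""
    model.eval()
    with torch.no_grad():
        # "tokens"/"labels" keys are illustrative; the real layout depends
        # on the dataset and tokenizer used by the recipe.
        logits = model(batch["tokens"])
        loss = loss_fn(logits, batch["labels"])
    model.train()
    return loss
```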
Epoch-Aware Evaluation with Multi-Rank Synchronization

- Epoch completion detection: Evaluates for exactly one complete epoch by monitoring `batch["metrics"]` for epoch increments
  - Reads `num_epochs` from batch metadata to detect when the validation dataset completes one full pass
- Non-blocking all_reduce pattern: Synchronizes epoch completion across all ranks without blocking computation
  - Issues an `async_op=True` all_reduce on the next batch's epoch while the GPU computes the current batch's loss

Integration
- Runs evaluation every `eval_interval` steps during training
- If `eval_steps > 0`, it acts as a cap (useful for quick validation checks or when epoch metadata is unavailable)

Usage:
Configure in your YAML config file:
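For example (illustrative values; exact key placement depends on the recipe's config layout):

```yaml
eval_interval: 500   # evaluate every 500 training steps
eval_steps: 0        # 0 = run one full validation epoch per eval
```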
If `eval_interval` and `eval_steps` are not set, evaluation is automatically disabled.
Testing:
Comprehensive test suite (`test_evaluate.py`) validates:
✅ Epoch extraction from batch metadata
✅ Single epoch completion detection
✅ eval_steps cap enforcement
✅ Empty/single batch edge cases
✅ Async all_reduce pattern behavior
✅ Multi-rank synchronization logic
✅ Prefetch pattern correctness
All 14 tests pass successfully.
Algorithm Details:
The non-blocking evaluation loop follows this pattern:
- Iteration N: prefetch batch N+1 and launch a non-blocking (`async_op=True`) all_reduce on its epoch counter, then compute the loss for batch N while that communication is in flight
- Iteration N+1: wait on the all_reduce result; if any rank has rolled over into a new epoch, all ranks break out of the loop together
This overlaps network communication with GPU computation for better performance, while ensuring all ranks stop at the same point.
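A condensed sketch of that loop (illustrative only; it assumes the epoch counter is exposed as `batch["metrics"]["num_epochs"]` and reuses the `forward_only` sketch above):

```python
import torch
import torch.distributed as dist

def evaluate(model, loss_fn, val_dataloader, eval_steps: int = 0) -> float:
    """Sketch of the epoch-aware, non-blocking evaluation loop."""
    it = iter(val_dataloader)          # infinite iterable: next() never raises
    batch = next(it)
    start_epoch = batch["metrics"]["num_epochs"]
    total_loss, num_batches = 0.0, 0

    while True:
        # Prefetch the next batch and kick off a non-blocking all_reduce on
        # its epoch counter, so communication overlaps with this forward pass.
        next_batch = next(it)
        epoch_t = torch.tensor(
            [next_batch["metrics"]["num_epochs"]], device="cuda"
        )
        handle = dist.all_reduce(epoch_t, op=dist.ReduceOp.MAX, async_op=True)

        total_loss += forward_only(model, loss_fn, batch).item()
        num_batches += 1

        handle.wait()  # all ranks now agree on the max epoch of the next batch
        if epoch_t.item() > start_epoch:
            break  # some rank rolled into a new epoch -> everyone stops together
        if eval_steps and num_batches >= eval_steps:
            break  # optional hard cap on the number of eval batches
        batch = next_batch

    return total_loss / max(num_batches, 1)
```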