-
Notifications
You must be signed in to change notification settings - Fork 17
Adding eval to the SFT #404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
hey @HosseinKaviani-H , thanks for opening the PR. Its a bit tricky to run the validation, because the dataset is infinite. So it doesnt know when to stop. You can retrieve the epoch number for each dataset from batch["metrics'], but we haven't looked into that. On top of that, if you have multiple datasets, they will epoch at different paces. I think that there are a few ways on handling this:
It seems that you defined "eval_steps" as a clever solution to not deal with any of that. But i wonder about correctness here, i.e. not going through the entire eval, or going 1.5x times, for example. Any thoughts? |
Hi @felipemello1 , Thanks for your comment. Yeah I think one good solution as you mentioned is to retrieve the epoch number in the training loop and once it hits 0 to 1, it breaks. I'll try to give it some thoughts and implement it. And yes, counting batches is arbitrary here as if eval_steps is too low it could lead to incomplete evaluation or too high it might cause duplicate evaluation. Hence, checking epoch number sounds a better solution here. |
Leaving this comment here before a full review since I think it's relevant to the point raised by @felipemello1: previously @DNXie observed challenges with iterable dataset hanging when there are uneven numbers of samples across the ranks. In general this is a pretty hard problem to solve cleanly. So actually I would recommend going with the approach of using a fixed number of steps. You can see the full context in this torchtitan issue: pytorch/torchtitan#1618 |
@ebsmothers this should never happen to us, since we have inifinite datasets. Thats one of the main args for infinite iterable: you dont have to worry about hanging issues. It just restarts the iter and keeps providing new samples. |
@felipemello1 sorry maybe I don't fully understand your suggestions then. What is the termination condition for the validation loop? If it is epoch-based in any way I think we will run into this issue, right? |
we can identify the change in epoch and drop last using all_gather + barrier. Dummy example for single dataset:
In the example above, for bsz=4, maybe rank_0 would have 2 samples from epoch 0 and 2 from epoch 1. But the batch size would always be 4. It would never hang. Maybe this could be done elegantly inside of the dataset and hide the logic from the recipe? but i dont think that there is a pretty way. Also not sure how to handle the multidataset situation. Perhaps:
does it make sense @ebsmothers ? |
@felipemello1 that's an interesting idea. In addition to your point about it not being super pretty, I am also wary of the |
We could add to the ugliness and prefetch + check epoch change on a different stream one epoch in advance, so it would be non blocking. This can be an utility and removed from the recipe. It would also only happen for validation (training is safe). |
@ebsmothers @felipemello1 Given our discussion and per Felipe's idea, I have implemented an epoch-based eval with non-blocking all-reduce. I have updated the description and added a test_evaluate script to cover different scenarios. |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #404 +/- ##
=======================================
Coverage ? 73.43%
=======================================
Files ? 81
Lines ? 7829
Branches ? 0
=======================================
Hits ? 5749
Misses ? 2080
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
hey Hossein, thanks! I think that the tests are just mocking distributed and not testing it. @ebsmothers , do we have a decorator for distributed tests in forge? Regarding the implementation, i dont think we need >100 lines to do the sampling + epoch checking. Probably we can shrink it a bit |
|
||
dataset = sft_iterable_dataset( | ||
model_transform=tokenizer, | ||
message_transform=AlpacaToMessages(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. We can have this implemented as soon as we fix the main eval functionality
@felipemello1 I have shortened the code a bit. Let me know if the distributed testing so I can have that implemented as well |
# Prefetch first batch | ||
try: | ||
next_batch = next(val_dataloader) | ||
except StopIteration: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove the defensive checks and assume that the dataset is infinite. We have a class for it. I think you can just do an assertion that its TuneIterableDataset (we need to update the name and remove "tune", but dont worry about this on this PR:
torchforge/src/forge/data/datasets/dataset.py
Line 117 in d464193
class InfiniteTuneIterableDataset(TuneIterableDataset): |
Then we know its infinite, and we can remove try/except here and later in the loop. wdyt?
return None | ||
|
||
async def evaluate(self) -> dict[str, float]: | ||
"""Run evaluation with async all_reduce for cross-rank epoch synchronization.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it might be worth enhancing this docstring a bit. Maybe add a small numerical example.
if pending_work is not None: | ||
pending_work.wait() | ||
should_break = ( | ||
epoch_tensor.item() > 0 if epoch_tensor is not None else False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as a rule of thumb, its not good to use item(), because it requires cpu-gpu synchonization
next_batch = next(val_dataloader) | ||
next_epoch = self._extract_epoch_from_batch(next_batch) | ||
|
||
if next_epoch is not None and starting_epoch is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i am tempted to say lets remove the if/else check for None. Do you see a strong argument for keeping them? It would only be a problem if someone replaced our dataset abstraction, otherwise we always have the metric
with torch.no_grad(): | ||
while True: | ||
# Wait for previous async all_reduce to complete | ||
if pending_work is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am thinking we could abstract most of it into some utility and have this (feel free to change var names)
epoch_incremented, next_max_epoch = False, None
with torch.no_grad():
while True:
# check if epoch incremented before getting new batch.
# If so, stop iterating on the dataset
epoch_incremented: bool = check_if_epoch_incremented(batch, next_max_epoch)
if epoch_incremented:
logger.info("bla bla bla")
break
# get next batch
batch = next_batch
next_batch = next(val_dataloader)
# start non-blocking all_reduce for next batches epoch
next_max_epoch: futures = get_distributed_max_epoch(next_batch)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not 100% sure this works. I think that get_distributed_max_epoch
may need to return a tensor and futures?
for model_part in self.model_parts: | ||
model_part.train() | ||
|
||
avg_loss = total_loss / max(num_batches, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be incorrect. If we have one sample with 10 tokens and another sample with 100 tokens, we have to /110, not by 2. @ebsmothers can you confirm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This kinda depends. One interpretation is that the val loss we report is the average over all batches in the val dataloader. In that case this would be the correct implementation. A second interpretation would be that the val loss we report is actually the aggregate loss over all tokens in the val dataset. This is the more "correct" one, as @felipemello1 points out. It especially matters for training.. btw this is why I did not enable gradient accumulation initially, as it needs some special handling. See e.g. the "Long digression" section in the summary of meta-pytorch/torchtune#1917. Since then titan has made some strides in this direction, see this function and this issue.
So I think the ideal state is that we use token-normalized loss (rather than batch-normalized as has been done here) for both training and validation. But for training it will likely require a little more work to fully support in titan. For validation it's more straightforward, you can see this snippet in torchtune's validation logic.
for model_part in self.model_parts: | ||
model_part.eval() | ||
|
||
val_dataloader = iter(self.val_dataloader) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did you have a chance to consider the case for multidataset? what happens in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay. Left some comments/suggestions. We would need to test it in some distributed capacity. Were you able to run it for >1 node and confirm that it stopped right after 1 epoch?
@felipemello1 Sorry I missed this before. We do have this utility but not sure if that's sufficient here. Another commonly-used class is FSDPTest, which handles a lot of the setup and teardown logic for a distributed test. |
dataset_val_config = self.job_config.get("dataset_val", {}) | ||
self.val_dataloader = self.setup_data( | ||
dataset_path=dataset_val_config.get("path", dataset_config.get("path")), | ||
dataset_split=dataset_val_config.get("split", dataset_config.get("split")), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit but I don't like these nested .get calls. It also seems strange that we would fallback to validation on the training set. Personally I would just recommend checking if validation is enabled, and if it's not, don't even set up the validation dataloader at all.
return metric.value | ||
return None | ||
|
||
async def evaluate(self) -> dict[str, float]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In addition to @felipemello1's more detailed comments, one higher-level point: this eval implementation is adding a lot of code to the main.py file, which as the entry point is something that everyone will have to read. (Specifically this PR alone has increased the total LoC by more than 50%, and the evaluate method alone is more than 100 lines due to boundary checking, edge case handling, etc.) I would like to see if we can find a more minimal way to introduce eval that doesn't expose the user to so much code complexity.
I think Felipe's suggestions of offloading to the dataset class, utilities, etc. are valuable. But would also like to re-raise the option of simplifying by only allowing eval for a fixed number of steps (at least for a first pass). Not gonna block on this, if we can do the cross-epoch accounting in a bit more of a clean, minimal way I am all for it.
Add periodic evaluation during training with epoch-aware synchronization
Added evaluation functionality to the SFT training recipe with proper multi-rank synchronization and epoch completion detection.
Changes:
Core Evaluation Features
Configurable evaluation interval: Added
eval_interval
andeval_steps
parameters to control when and how much to evaluateeval_interval
: Number of training steps between evaluations (defaults tofloat('inf')
to disable eval when not configured)eval_steps
: Number of validation batches to evaluate per evaluation run (defaults to0
for unlimited - runs one full epoch)Validation dataloader: Set up separate validation dataloader using the last 10% of the train split
Forward-only pass: Implemented
forward_only()
method for evaluation without gradient computation, supporting both pipeline parallel and non-PP configurationsEpoch-Aware Evaluation with Multi-Rank Synchronization
Epoch completion detection: Evaluates for exactly one complete epoch by monitoring
batch["metrics"]
for epoch incrementsnum_epochs
from batch metadata to detect when validation dataset completes one full passNon-blocking all_reduce pattern: Synchronizes epoch completion across all ranks without blocking computation
async_op=True
all_reduce on next batch's epoch while GPU computes current batch's lossIntegration
eval_interval
steps during trainingeval_steps > 0
, it acts as a cap (useful for quick validation checks or when epoch metadata is unavailable)Usage:
Configure in your YAML config file:
If eval_intervaland eval_steps are not set, evaluation is automatically disabled.
Testing:
Comprehensive test suite (test_evaluate.py) validates:
✅ Epoch extraction from batch metadata
✅ Single epoch completion detection
✅ eval_steps cap enforcement
✅ Empty/single batch edge cases
✅ Async all_reduce pattern behavior
✅ Multi-rank synchronization logic
✅ Prefetch pattern correctness
All 14 tests pass successfully.
Algorithm Details:
The non-blocking evaluation loop follows this pattern:
Iteration N:
Iteration N+1:
This overlaps network communication with GPU computation for better performance, while ensuring all ranks stop at the same point.
This updated description captures: