[wip][SFT Eval ] Add eval to SFT script #536
base: main
Conversation
- Add eval_utils.py with run_evaluation() function for multi-dataset evaluation
- Update main.py to support multi-dataset configuration and evaluation
- Add validation config settings (enabled, eval_interval, eval_steps)
- Refactor setup() to support dataset_val.datasets structure
- Add unified forward() method with compute_gradients flag
- Add evaluate() method that calls run_evaluation()
- Update llama3_8b.yaml with multi-dataset configuration
- Fix extract_epoch_from_batch() to use 'key' attribute instead of 'metric_name'
- Simplify epoch tracking: compare consecutive batches instead of tracking from start (see the sketch below)
- Remove starting_epoch variable - no longer needed
- Update start_epoch_sync() to use boolean epoch_changed instead of epoch_increment
- Add better logging for epoch changes and tracking status
- Epoch sync now works correctly with the actual metric structure
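A minimal sketch of the consecutive-batch comparison described above. The batch/metric layout (a "metrics" list on the batch whose entries expose `key` and `value` attributes) is an assumption for illustration, not necessarily the exact structure in this repo.

```python
# Hypothetical sketch: the exact batch/metric structure is assumed, not taken from the repo.
def extract_epoch_from_batch(batch: dict) -> int | None:
    """Return the num_epochs value carried by the batch's Metric objects, if any."""
    for metric in batch.get("metrics", []):
        if getattr(metric, "key", "").endswith("num_epochs"):
            return int(metric.value)
    return None


def epoch_changed(prev_batch: dict, batch: dict) -> bool:
    """True if the epoch counter advanced between two consecutive batches."""
    prev_epoch = extract_epoch_from_batch(prev_batch)
    epoch = extract_epoch_from_batch(batch)
    return prev_epoch is not None and epoch is not None and epoch > prev_epoch
```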
dataset_name: str | None = None,
filter_fn: Callable | None = None,
filter_kwargs: dict[str, Any] | None = None,
dp_mesh: dist.ProcessGroup | None = None,
[Explaining changes]
With an iterable dataset, we shard the data by splitting it across ranks. This is done by calling:
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
Before: I was using the global rank.
Problem: If using TP/CP/PP etc., each rank would get a different data point. This is wrong. We need to split per DP rank: all ranks within the same DP group should get the same data point.
Solution: pass dp_mesh and derive the rank/world size from it (see the sketch below).
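A minimal sketch of the fix, assuming `dp_mesh` is a `torch.distributed` process group; the helper name here is hypothetical.

```python
import torch.distributed as dist
from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node


def shard_for_dp(ds: IterableDataset, dp_mesh: dist.ProcessGroup | None = None) -> IterableDataset:
    # Shard by DP rank so that TP/CP/PP ranks inside the same DP group see identical batches.
    if dp_mesh is not None:
        rank = dist.get_rank(dp_mesh)
        world_size = dist.get_world_size(dp_mesh)
    else:
        # Fallback: previous behavior, using the global rank and world size.
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    return split_dataset_by_node(ds, rank=rank, world_size=world_size)
```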
# Internal state for resumption
# _start_epoch: The epoch to start from. Updated on resume from ckpt.
# useful when doing iter(ds), which restarts dataset from original state.
self._start_epoch = 0
[Explaining changes]
We want to restart the dataset on every eval. We can do this by calling iter(ds), which is cheap.
In __iter__ we then do:
self._num_epochs = self._start_epoch
self._ds.set_epoch(self._num_epochs)  # used to preserve shuffle order
In other words: we need self._start_epoch so we know where to reset to. You may ask "why not just always set it to 0?". Because for training, we may have resumed from a checkpoint and want iter(ds) to start from elsewhere. That's why we do:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self._start_epoch = state_dict["num_epochs"]
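For illustration, a simplified, hypothetical version of how the iterator could use _start_epoch; attribute names follow the snippets above, but the loop body is assumed rather than copied from the repo.

```python
# Hypothetical, simplified iterator wrapper (attribute names assumed).
class _IterableWrapper:
    def __init__(self, ds, start_epoch: int = 0):
        self._ds = ds                    # an HF IterableDataset exposing set_epoch()
        self._start_epoch = start_epoch  # 0 by default; overwritten by load_state_dict on resume

    def __iter__(self):
        # Always restart from _start_epoch, so iter(ds) resets cheaply for eval
        # while still honoring a checkpointed epoch for training.
        self._num_epochs = self._start_epoch
        self._ds.set_epoch(self._num_epochs)  # preserves the per-epoch shuffle order
        while True:  # infinite iterable dataset: restart automatically on exhaustion
            for sample in self._ds:
                yield sample
            self._num_epochs += 1
            self._ds.set_epoch(self._num_epochs)

    def load_state_dict(self, state_dict: dict) -> None:
        self._start_epoch = state_dict["num_epochs"]
```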
self._reset_packer_state()
self._iterator = iter(self.dataset)
[Explaining changes]
Similar to hf_dataset.py changes, when we call iter(ds), we always want it to restart from its original state. There is no need for the extra checks that were here.
This looks great to me. Thanks for the PR. Just had a few minor comments/questions.
P.S. Do you have test results on TP/CP where ranks get different samples?
break

# Move tensors to device
for key, value in batch.items():
We have a helper function to do this.
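For reference, a hypothetical version of such a helper; the name and exact signature are assumptions, and the repo's own utility may differ.

```python
import torch


def batch_to_device(batch: dict, device: torch.device) -> dict:
    """Move tensor values to the target device; leave non-tensor entries (e.g. metrics) untouched."""
    return {
        key: value.to(device, non_blocking=True) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }
```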
apps/sft/main.py (outdated)
num_steps += 1

# Log progress (rank 0 only)
if num_steps % 50 == 0:
hardcoded to log every 50 steps?
made it an arg
for batch in batch_iter:
    # Check max_eval_steps limit
    if (
        self.max_eval_steps is not None
what if max eval steps > num steps per epoch?
batch_iter will stop first; I can add a comment to clarify.
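A sketch of the interaction being described, with assumed variable names: `batch_iter` is bounded to a single epoch, so if `max_eval_steps` exceeds the number of batches per epoch, the iterator is simply exhausted first.

```python
def run_eval_loop(batch_iter, max_eval_steps: int | None = None) -> int:
    num_steps = 0
    for batch in batch_iter:  # StopAfterOneEpoch ends after one full pass over the data
        if max_eval_steps is not None and num_steps >= max_eval_steps:
            break  # optional early stop; otherwise the epoch boundary stops iteration first
        # ... forward pass and metric recording would go here ...
        num_steps += 1
    return num_steps
```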
)


class StopAfterOneEpoch:
I don't fully understand why this is needed. Why is this only showing up for evaluation?
Also, IIUC the motivation seems two-fold:
- Avoid "hangs"
- Fetch batches in advance and overlap all_reduce
The first is a blocker and should be addressed. The latter is an optimization that I do not think is necessary, nor should it be done in this PR.
> Why is this only showing up for evaluation?

We do not do "epochs" for training; we train for a number of steps. It has to be done this way because, with multiple datasets, the concept of an "epoch" is unclear. For evaluation, however, we want to do a single epoch over the dataset.

> Fetch batches in advance and overlap all_reduce

Correct. We could get rid of the utility and just have the all_reduce at the start of every loop iteration. Evan was strongly against it though, hence the utility :/
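For context, a minimal sketch of the prefetch-plus-async-all_reduce pattern being discussed; the names, constructor arguments, and the epoch-extraction callable are assumptions, not the PR's exact implementation.

```python
import torch
import torch.distributed as dist


class StopAfterOneEpoch:
    """Yield batches until any DP rank reaches epoch 1, i.e. after exactly one eval epoch."""

    def __init__(self, dataloader, epoch_of, dp_group: dist.ProcessGroup | None = None):
        self._it = iter(dataloader)
        self._epoch_of = epoch_of  # callable: batch -> epoch number (from the num_epochs metric)
        self._dp_group = dp_group

    def __iter__(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        batch = next(self._it)
        while True:
            # Prefetch the next batch and launch the epoch check asynchronously,
            # so the all_reduce overlaps with the caller's forward pass on `batch`.
            next_batch = next(self._it)
            flag = torch.tensor([self._epoch_of(next_batch)], device=device)
            work = dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self._dp_group, async_op=True)
            yield batch  # caller consumes the current batch here
            work.wait()
            if flag.item() >= 1:  # some rank has already entered epoch 1 -> all ranks stop together
                return
            batch = next_batch
```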
When did @ebsmothers weigh in on this? I'd personally rather avoid the redirect and keep as much logic visible as possible.
I leave it to your discretion, but I will just reiterate the guiding principle: I think we should only address the immediate blocker for now and leave any optimizations for a separate PR (all while minimizing redirects as much as we can).
joecummings left a comment
Still some comments, but approved
self.rank_should_record_loss = False

# Logging frequency
self.log_every_n_steps = self.job_config.get("log_every_n_steps", 10)
Can we make this mandatory, no getter?
# Load eval datasets
eval_config = self.job_config.get("eval", {})
self.val_dataloaders = {}
self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None)
No getter
eval_config = self.job_config.get("eval", {})
self.val_dataloaders = {}
self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None)
max_eval_steps = eval_config.get("max_eval_steps", None)
No getter
so put everything in the yaml?
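In that spirit, a sketch of what "no getter" could look like; the config keys follow the diff above, but the helper itself is hypothetical.

```python
def read_eval_config(job_config: dict) -> tuple[int, int | None]:
    # Mandatory keys are indexed directly: a missing entry in the yaml raises
    # KeyError at startup instead of silently falling back to a default.
    eval_config = job_config["eval"]
    eval_every_n_steps = eval_config["eval_every_n_steps"]
    max_eval_steps = eval_config["max_eval_steps"]  # may still be explicitly set to null in the yaml
    return eval_every_n_steps, max_eval_steps
```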
labels: torch.Tensor,
skip_backward: bool = False,
) -> torch.Tensor:
    """Forward pass with optional backward."""
nit: No need for this comment
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
-   torch.mean(torch.stack(losses)).to(self.device)
+   torch.sum(torch.stack(losses)).to(self.device)
Why did you change this?
For (cp=2, tp=2), (tp=4), or tp=1, I would get the same loss: 1.30.
Using pp=2, I would get 0.65; with pp=4, 0.325.
Changing it to sum fixed the loss scale back to 1.30. I still need to print(losses) and understand what's going on.
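A toy illustration of the observed numbers under one possible explanation, which is only an assumption to verify as noted above: if each PP chunk contributes a partial loss that should be summed, averaging divides the result by the number of chunks.

```python
losses = [0.325, 0.325, 0.325, 0.325]  # hypothetical per-chunk losses with pp=4

mean_loss = sum(losses) / len(losses)  # 0.325 -> matches the pp=4 value seen with torch.mean
sum_loss = sum(losses)                 # 1.30  -> matches the tp/cp baseline seen with torch.sum
```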
loss_val = loss.item()
record_metric("ForgeSFTRecipe/train_step/loss", loss_val, Reduce.MEAN)
if self.current_step % self.log_every_n_steps == 0:
    logger.info(
I thought this was handled by the MetricLogger?
Comments are minor. Will merge and address in a follow-up this week.
Continuation of the work by @HosseinKaviani-H in #404
TL;DR: Adds evaluation to the SFT script; the eval loop handles `drop_last` and stops eval after 1 epoch.

Context:
In forge we use infinite iterable datasets (not map-style). There are advantages to it, such as:
a) Streaming / Not holding the entire dataset in memory;
b) Easy manipulation of data, e.g. `while True: do_something(next(iter))`;
c) The dataset can be generated on the fly, e.g. a replay buffer.
However, there are also challenges:
a) We do not know the size of the dataset in advance;
b) We don't know epoch boundaries (the dataset resets automatically on exhaustion. This is done so that we don't have to deal with potential hangs from different ranks not getting enough samples when the dataset is exhausted)
Original problem:
For validation, we want to run only 1 epoch. In map-style datasets, this is easy: i) break after one iteration over the loop; ii) set dataloader(drop_last=True) to avoid hangs;
As discussed above, this is not possible with infinite iterable datasets.
To identify epoch boundaries, our dataset implementation returns a Metric `num_epochs`. We can use it to easily verify whether we started a new epoch, and stop there. However, in a distributed setting, we may have `len(dataset) % num_ranks != 0`. This means that some ranks may be on epoch 0 while others are already on epoch 1. To avoid hangs, all ranks must stop at the same time, so we need to do some sort of `all_reduce` to know if at least one rank has seen `epoch == 1`, introducing communication overhead and blocking the forward pass.

Solution:
This PR implements
`StopAfterOneEpoch(dataloader)`, which fetches one batch in advance and runs the `all_reduce` asynchronously, overlapping the communication. The utility elegantly abstracts this away from the user.

Issues found/fixed during this implementation:
`HfIterableDataset` sharded the data on all ranks, not only DP ranks. This means that TP ranks were getting different batches instead of repeated ones.