
Conversation

@felipemello1 (Contributor) commented Nov 6, 2025

Continuation of the work by @HosseinKaviani-H in #404

TL;DR:

  • Adds eval loop to SFT
  • Adds a non-blocking equivalent to drop_last and stops eval after 1 epoch.
  • Adds config for multiple eval datasets and multiple train datasets (train does not support >1 dataset yet, but it will in a different PR)
  • [FIX] Fixes and adds new unit tests covering blind spots in our dataset when using TP/CP
  • [FIX] Enables the dataset to reset by calling iter(dataset)

Context:

In forge we use infinite iterable datasets (not map-style); a minimal sketch of the pattern follows the two lists below. Advantages include:
a) Streaming / not holding the entire dataset in memory;
b) Easy manipulation of data, e.g. while True: do_something(next(iter));
c) The dataset can be generated on the fly, e.g. a replay buffer.

However, there are also challenges:
a) We do not know the size of the dataset in advance;
b) We don't know epoch boundaries (the dataset resets automatically on exhaustion, so we don't have to deal with potential hangs from different ranks not getting enough samples when the dataset runs out).
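
As a rough illustration of the pattern (hypothetical class, not the forge implementation): the dataset restarts itself on exhaustion, so consumers never see StopIteration, and epoch boundaries are only visible through a counter/metric.

class InfiniteIterableDataset:
    def __init__(self, samples):
        self.samples = list(samples)
        self.num_epochs = 0

    def __iter__(self):
        while True:
            yield from self.samples
            self.num_epochs += 1  # silently roll over into the next epoch


it = iter(InfiniteIterableDataset(range(3)))
first_four = [next(it) for _ in range(4)]  # [0, 1, 2, 0] -- no StopIteration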

Original problem:

For validation, we want to run only 1 epoch. With map-style datasets this is easy: i) break after one iteration over the loop; ii) set DataLoader(drop_last=True) to avoid hangs.

As discussed above, this is not possible with infinite iterable datasets.

To identify epoch boundaries, our dataset implementation returns a Metric named num_epochs. We can use it to easily verify whether we started a new epoch, and stop there:

for batch in dataloader:
    if any(metric.value > 0 for metric in batch["metrics"] if metric.key == "num_epochs"):
        break

However, in a distributed setting, we may have len(dataset) % num_ranks != 0. This means that some ranks may be on epoch 0 while others are already in epoch 1.

To avoid hangs, all ranks must stop at the same time. This means that we need to do some sort of all_reduce to know if at least one rank has seen epoch==1, introducing communication overhead and blocking the forward pass.

Solution:

This PR implements StopAfterOneEpoch(dataloader), which fetches one batch in advance and runs the all_reduce asynchronously, overlapping communication with computation. The utility abstracts this away from the user.
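
Roughly, the idea looks like this (a minimal sketch under assumptions: batches carry a "metrics" list of objects with .key/.value, and the backend is NCCL so the flag lives on GPU; this is not the exact utility in the PR):

import torch
import torch.distributed as dist


class StopAfterOneEpoch:
    """Sketch: wrap an 'infinite' dataloader and stop once ANY rank crosses the
    epoch boundary, overlapping the epoch check (all_reduce) with compute."""

    def __init__(self, dataloader, dp_group=None, device="cuda"):
        self.it = iter(dataloader)
        self.dp_group = dp_group
        self.device = device
        self.next_batch = next(self.it)
        self.pending = self._launch_epoch_check(self.next_batch)

    def _epoch_started(self, batch) -> bool:
        # Assumed metric layout: batch["metrics"] holds objects with .key/.value.
        return any(m.value > 0 for m in batch["metrics"] if m.key == "num_epochs")

    def _launch_epoch_check(self, batch):
        flag = torch.tensor([int(self._epoch_started(batch))], device=self.device)
        # MAX across ranks: stop as soon as at least one rank started a new epoch.
        work = dist.all_reduce(
            flag, op=dist.ReduceOp.MAX, group=self.dp_group, async_op=True
        )
        return flag, work

    def __iter__(self):
        while True:
            batch = self.next_batch
            flag, work = self.pending
            self.next_batch = next(self.it)  # prefetch while the all_reduce is in flight
            work.wait()
            if flag.item() > 0:  # some rank is already in epoch 1 -> everyone stops
                return
            self.pending = self._launch_epoch_check(self.next_batch)
            yield batch

Used as for batch in StopAfterOneEpoch(val_dataloader, dp_group): ..., every DP rank exits on the same iteration, so no rank hangs waiting for a collective.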

Issues found/fixed during this implementation:

  • HfIterableDataset sharded the data across all ranks, not only DP ranks. This means that TP ranks were getting different batches instead of repeated ones.
  • The datasets had to be reset after every eval. Some changes had to be made so that calling iter(dataset) provides a fresh iterator with the original state. This is much faster than creating a new dataset on every eval loop.

Hossein Kavianihamedani and others added 5 commits October 27, 2025 11:51
- Add eval_utils.py with run_evaluation() function for multi-dataset evaluation
- Update main.py to support multi-dataset configuration and evaluation
  - Add validation config settings (enabled, eval_interval, eval_steps)
  - Refactor setup() to support dataset_val.datasets structure
  - Add unified forward() method with compute_gradients flag
  - Add evaluate() method that calls run_evaluation()
- Update llama3_8b.yaml with multi-dataset configuration
- Fix extract_epoch_from_batch() to use 'key' attribute instead of 'metric_name'
- Simplify epoch tracking: compare consecutive batches instead of tracking from start
- Remove starting_epoch variable - no longer needed
- Update start_epoch_sync() to use boolean epoch_changed instead of epoch_increment
- Add better logging for epoch changes and tracking status
- Epoch sync now works correctly with the actual metric structure
@meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Nov 6, 2025
@felipemello1 marked this pull request as draft on November 6, 2025 21:43
@felipemello1 marked this pull request as ready for review on November 7, 2025 19:16
@felipemello1 changed the title from "[WIP][SFT Eval] Add eval to SFT script" to "[SFT Eval] Add eval to SFT script" on Nov 7, 2025
@felipemello1 changed the title from "[SFT Eval] Add eval to SFT script" to "[wip][SFT Eval] Add eval to SFT script" on Nov 7, 2025
dataset_name: str | None = None,
filter_fn: Callable | None = None,
filter_kwargs: dict[str, Any] | None = None,
dp_mesh: dist.ProcessGroup | None = None,
felipemello1 (Contributor, Author):

[Explaining changes]

With an iterable dataset, we shard the data and split it across ranks. This is done by calling:

ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)

Before: I was using the global rank.

Problem: If using TP/CP/PP, etc., each rank would get a different data point. This is wrong. We need to split per DP rank: all ranks within the same DP group should get the same data point.

Solution: pass dp_mesh (sketch below).
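
A sketch of the intended sharding, assuming dp_mesh is the data-parallel process group (variable names are illustrative):

import torch.distributed as dist
from datasets.distributed import split_dataset_by_node

# Shard over DP ranks only, so TP/CP/PP ranks inside the same DP group
# receive identical batches.
dp_rank = dist.get_rank(group=dp_mesh)
dp_world_size = dist.get_world_size(group=dp_mesh)
ds = split_dataset_by_node(ds, rank=dp_rank, world_size=dp_world_size)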

# Internal state for resumption
# _start_epoch: The epoch to start from. Updated on resume from ckpt.
# useful when doing iter(ds), which restarts dataset from original state.
self._start_epoch = 0
felipemello1 (Contributor, Author):

[Explaining changes]

We want to restart the dataset on every eval. We can do this by calling iter(ds), which is cheap.

In def __iter__ we then do:

self._num_epochs = self._start_epoch
self._ds.set_epoch(self._num_epochs)  # used to preserve shuffle order

In other words, we need to add self._start_epoch so that we know where to reset to. You may ask: "why not just always set it to 0?" Because for training we may have resumed from a checkpoint and want iter(ds) to start from elsewhere. That's why we do:

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    self._start_epoch = state_dict["num_epochs"]
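
Putting it together, the reset path could look roughly like this (a sketch of the idea, not the exact code in the PR):

def __iter__(self):
    # Always restart from a known state: 0 for a fresh dataset, or the
    # checkpointed epoch after load_state_dict().
    self._num_epochs = self._start_epoch
    self._ds.set_epoch(self._num_epochs)  # preserve shuffle order per epoch
    while True:
        for sample in self._ds:
            yield sample
        self._num_epochs += 1
        self._ds.set_epoch(self._num_epochs)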

Comment on lines +452 to +453
self._reset_packer_state()
self._iterator = iter(self.dataset)
felipemello1 (Contributor, Author):

[Explaining changes]

Similar to hf_dataset.py changes, when we call iter(ds), we always want it to restart from its original state. There is no need for the extra checks that were here.

@HosseinKaviani-H (Contributor) left a comment:

This looks great to me. Thanks for the PR. Just had a few minor comments/questions.

P.S. Do you have test results on TP/CP where ranks get different samples?

break

# Move tensors to device
for key, value in batch.items():
Member:

We have a helper function to do this.

apps/sft/main.py Outdated
num_steps += 1

# Log progress (rank 0 only)
if num_steps % 50 == 0:
Member:

hardcoded to log every 50 steps?

felipemello1 (Contributor, Author):

made it an arg

for batch in batch_iter:
# Check max_eval_steps limit
if (
self.max_eval_steps is not None
Member:

what if max eval steps > num steps per epoch?

@felipemello1 (Contributor, Author), Nov 10, 2025:

batch_iter will stop first; I can add a comment to clarify.
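
For illustration, a hypothetical shape of the eval loop (run_eval_step is a stand-in, not code from this PR):

# batch_iter is the StopAfterOneEpoch wrapper, so it terminates at the epoch
# boundary on its own; a max_eval_steps larger than the epoch simply never triggers.
for num_steps, batch in enumerate(batch_iter):
    if max_eval_steps is not None and num_steps >= max_eval_steps:
        break
    run_eval_step(batch)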

)


class StopAfterOneEpoch:
Member:

I don't fully understand why this is needed. Why is this only showing up for evaluation?

Also, IIUC the motivation seems two-fold:

  1. Avoid "hangs"
  2. Fetch batches in advance and overlap all_reduce

The first is a blocker and should be addressed. The latter is an optimization, and I do not think it is necessary or should be done in this PR.

@felipemello1 (Contributor, Author), Nov 10, 2025:

> Why is this only showing up for evaluation?

We do not do "epochs" for training; we do a number of steps. It has to be done this way because, when doing multi-dataset training, the concept of an "epoch" is unclear.

For evaluation, however, we want to do a single epoch on the dataset.

> Fetch batches in advance and overlap all_reduce

Correct. We could get rid of the utility and just have the all_reduce at the start of every loop. Evan was strongly against it, though, hence the utility :/

Member:

When did @ebsmothers weigh in on this? I'd personally rather not redirect and keep as much logic visible as possible.

I'll leave it to your discretion, but I will just reiterate the guiding principle that we should only address immediate unblocking for now and have a separate PR for any optimizations (all while minimizing redirects as much as we can).

@joecummings (Member) left a comment:
Still some comments, but approved

self.rank_should_record_loss = False

# Logging frequency
self.log_every_n_steps = self.job_config.get("log_every_n_steps", 10)
Member:

Can we make this mandatory, no getter?

# Load eval datasets
eval_config = self.job_config.get("eval", {})
self.val_dataloaders = {}
self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None)
Member:

No getter

eval_config = self.job_config.get("eval", {})
self.val_dataloaders = {}
self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None)
max_eval_steps = eval_config.get("max_eval_steps", None)
Member:

No getter

felipemello1 (Contributor, Author):

so put everything in the yaml?

labels: torch.Tensor,
skip_backward: bool = False,
) -> torch.Tensor:
"""Forward pass with optional backward."""
Member:

nit: No need for this comment

# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
-    torch.mean(torch.stack(losses)).to(self.device)
+    torch.sum(torch.stack(losses)).to(self.device)
Member:

Why did you change this?

@felipemello1 (Contributor, Author), Nov 11, 2025:

For (cp=2, tp=2), (tp=4), or tp=1, I would get the same loss: 1.30.

Using pp=2 I would get 0.65, and with pp=4 I would get 0.325.

Changing it to sum fixed the loss scale back to 1.30. I need to print(losses) and understand what's going on.

loss_val = loss.item()
record_metric("ForgeSFTRecipe/train_step/loss", loss_val, Reduce.MEAN)
if self.current_step % self.log_every_n_steps == 0:
logger.info(
Member:

I thought this was handled by the MetricLogger?


@felipemello1 (Contributor, Author):

Comments are minor. Will merge and address in a follow-up this week.

HosseinKaviani-H pushed a commit to HosseinKaviani-H/forge that referenced this pull request on Nov 13, 2025