
Conversation

@DNXie (Member) commented on Aug 14, 2025

  • Make dataset path and split configurable
  • Add validation dataloader
  • Add validation loop

Added dataset_val and validation sections to the config file:

validation:
  local_batch_size: 1
  freq: -1  # Change to a positive number to enable validation
  steps: 200  # Max steps to run validation. Validation disabled if negative.

dataset_val:
  path: yahma/alpaca-cleaned
  split: train[95%:]
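
For reference, a minimal sketch of how these knobs could drive the validation loop (hypothetical helper and batch-key names; the actual logic lives in apps/sft/main.py):

import torch

def maybe_run_validation(model, val_dataloader, loss_fn, step, val_cfg):
    # Mirrors the validation section above: a non-positive freq disables
    # validation, and steps caps how many validation batches run each time.
    freq, max_steps = val_cfg["freq"], val_cfg["steps"]
    if freq <= 0 or step % freq != 0:
        return None
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for i, batch in enumerate(val_dataloader):
            if i >= max_steps:
                break
            logits = model(batch["tokens"])  # assumed batch keys
            total += loss_fn(logits, batch["labels"]).item()
            count += 1
    model.train()
    return total / max(count, 1)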

Test:

uv run forge run --nproc_per_node 2 apps/sft/main.py --config apps/sft/llama3_8b.yaml

For this test, the validation frequency (freq) is set to 2, i.e. validation runs every two training steps:

1|Loss: 1.0664817094802856:   0%|          | 2/1000 [00:06<57:53,  3.48s/it]
Validation loss: 1.311207890510559
Running validation Loss: 1.3039: : 200it [00:24,  8.28it/s]
Validation loss: 1.311207890510559
3|Loss: 1.2771787643432617:   0%|          | 4/1000 [00:31<2:15:20,  8.15s/it]
Validation loss: 1.3083338737487793
Running validation Loss: 1.2985: : 200it [00:24,  8.25it/s]
Validation loss: 1.3083338737487793
5|Loss: 1.3110204935073853:   1%|          | 6/1000 [00:57<2:37:35,  9.51s/it]
Validation loss: 1.2877949476242065
Running validation Loss: 1.2748: : 199it [00:24,  8.32it/s]
Validation loss: 1.2877949476242065
7|Loss: 1.3947221040725708:   1%|          | 8/1000 [01:22<2:46:21, 10.06s/it]
Validation loss: 1.274566411972046
Running validation Loss: 1.2768: : 200it [00:24,  8.12it/s]
Validation loss: 1.274566411972046
9|Loss: 1.1953884363174438:   1%|          | 10/1000 [01:48<2:50:45, 10.35s/it]
Validation loss: 1.2020115852355957
Running validation Loss: 1.1951: : 200it [00:24,  8.17it/s]
Validation loss: 1.2020115852355957
11|Loss: 1.2392996549606323:   1%|          | 12/1000 [02:13<2:53:13, 10.52s/it]
Validation loss: 1.1668553352355957
Running validation Loss: 1.1603: : 200it [00:24,  8.12it/s]
Validation loss: 1.1668553352355957
13|Loss: 1.041933536529541:   1%|▏         | 14/1000 [02:39<2:54:35, 10.62s/it]
Validation loss: 1.0708351135253906
Running validation Loss: 1.0622: : 200it [00:24,  8.22it/s]
Validation loss: 1.0708351135253906
15|Loss: 1.0107098817825317:   2%|▏         | 16/1000 [03:05<2:54:47, 10.66s/it]

Tested with steps = 10 and steps = 200 (which exhausts the dataset).

The meta-cla bot added the CLA Signed label on Aug 14, 2025.
apps/sft/main.py Outdated
Comment on lines +107 to +113
Contributor

One thing we should think about is how to support additional args beyond those we've already hardcoded. E.g. in #50 we also need to pass data_files. (This is more of a config system question so it's OK to punt on it for now, but one path is to use something like instantiate for this, you can see this section in the torchtune docs for an example)
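
For illustration, the instantiate-style pattern mentioned above could look roughly like this (a hedged sketch of the idea; the _component_ key and config layout are assumptions, not the existing forge or torchtune API):

import importlib

def instantiate(cfg: dict):
    # Build an object from a config dict with a "_component_" key and pass
    # every remaining key through as a kwarg, so extras such as data_files
    # need no hardcoding.
    cfg = dict(cfg)
    module_path, _, name = cfg.pop("_component_").rpartition(".")
    component = getattr(importlib.import_module(module_path), name)
    return component(**cfg)

# A config like
#   dataset_val:
#     _component_: datasets.load_dataset
#     path: json
#     data_files: data/val.jsonl
#     split: train
# would then be built with: ds = instantiate(cfg["dataset_val"])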

Member Author

I can support passing file paths. Which one (data_files or path) should take priority if the user passes both?

@DNXie changed the title from "make dataset configurable" to "make dataset configurable and add validation loop" on Aug 15, 2025.
apps/sft/main.py Outdated
Member Author

@DNXie commented on Aug 15, 2025

@ebsmothers
I tried to do this the way torchtune does it:

current_num_tokens = (
    batch["labels"] != self._loss_fn.ignore_index
).sum()

but the current self.loss_fn doesn't have an ignore_index attribute.

Contributor

Tbh I think this is something we kinda overparametrized in torchtune. You can just use the constant CROSS_ENTROPY_IGNORE_IDX
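
A minimal sketch of the suggested change (the constant is defined inline here to keep the snippet self-contained; torchtune exposes it as CROSS_ENTROPY_IGNORE_IDX):

import torch

# Marks label positions that should not contribute to the loss; -100 matches
# the default ignore_index of PyTorch's cross-entropy losses.
CROSS_ENTROPY_IGNORE_IDX = -100

def count_unmasked_tokens(labels: torch.Tensor) -> torch.Tensor:
    # Replaces the self._loss_fn.ignore_index lookup from the snippet above.
    return (labels != CROSS_ENTROPY_IGNORE_IDX).sum()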

Member Author

Thanks! Fixed.

apps/sft/main.py Outdated
Comment on lines 313 to 316
Contributor

nit: or BlockMask with flexattention enabled

Member Author

Thanks! Fixed.

Contributor

@ebsmothers left a comment

Just a few more comments, after that I think this should be good

apps/sft/main.py Outdated
Contributor

Let's move this out of main.py. Personally I would put it in src/forge/utils.py for now; we can relocate later as needed.

Member Author

Moved to utils.py

apps/sft/main.py Outdated
Comment on lines 46 to 49
Contributor

I think we can do without the conditional BlockMask logic for now. It has been in PyTorch stable for several releases now so we can just assume it's present (also I suspect we won't be running on non-CUDA devices for a bit either)
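
Concretely, the suggestion amounts to dropping the guarded import in favor of a direct one (a sketch of the shape of the change, not the PR's exact diff):

# Before: guard for older PyTorch builds without flex attention.
# try:
#     from torch.nn.attention.flex_attention import BlockMask
# except ImportError:
#     BlockMask = None

# After: BlockMask has been in stable PyTorch for several releases,
# so it can be imported unconditionally.
from torch.nn.attention.flex_attention import BlockMask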

Member Author

Updated. Please check.

apps/sft/main.py Outdated
Contributor

Can remove this?

Member Author

Removed.

apps/sft/main.py Outdated
Contributor

This won't work with pipeline parallel, right? Since the backward happens inside of step() there, I think we will need to handle that case differently.

Member Author

Updated the code to use an eval (forward-only) path when not doing the backward.
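
Roughly, the resulting validation step has this shape (a hedged sketch; pp_schedule.eval is an assumed forward-only entry point and the batch keys are illustrative):

import torch

def forward_only_loss(model, loss_fn, batch, pp_schedule=None):
    # With pipeline parallelism the loss and backward normally happen inside
    # the schedule's step(), so validation routes through a forward-only call
    # on the schedule instead of calling the model directly.
    with torch.no_grad():
        if pp_schedule is not None:
            return pp_schedule.eval(batch["tokens"])  # assumed API, for illustration
        logits = model(batch["tokens"])
        return loss_fn(logits, batch["labels"])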

Contributor

@ebsmothers left a comment

OK last comments I swear 😅

apps/sft/main.py Outdated
Contributor

Why do we need to do this?

Member Author

Removed self.model

apps/sft/main.py Outdated
Contributor

This is fine as a one-off, but let's add a TODO to do this in a nicer way. One of the things I don't like about torchtune is that we've littered the training script with similar .get(value, default) statements. Titan configs come with built-in default values, so we can consider landing these fields directly in ForgeJobConfig.
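
For illustration, the dataclass-default alternative might look roughly like this (field names taken from the config snippet above; not the actual ForgeJobConfig):

from dataclasses import dataclass, field

@dataclass
class ValidationConfig:
    local_batch_size: int = 1
    freq: int = -1    # non-positive disables validation
    steps: int = 200  # max validation batches per run

@dataclass
class JobConfigSketch:
    validation: ValidationConfig = field(default_factory=ValidationConfig)

# The trainer can then read cfg.validation.freq directly instead of
# sprinkling job_config.get("validation", {}).get("freq", -1) calls around.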

Member Author

I have removed this.

Contributor

@ebsmothers left a comment

We can probably remove all the instances of infinite now (not only in sft/main.py, but also in hf_dataset.py and sft_dataset.py). Other than that looks good to me!

@DNXie merged commit 5a3807e into meta-pytorch:main on Aug 22, 2025
4 checks passed
@DNXie deleted the add_data_config branch on September 10, 2025