make dataset configurable and add validation loop #54
Conversation
apps/sft/main.py (outdated)
One thing we should think about is how to support additional args beyond those we've already hardcoded. E.g. in #50 we also need to pass `data_files`. (This is more of a config system question, so it's OK to punt on it for now, but one path is to use something like `instantiate` for this; you can see this section in the torchtune docs for an example.)
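For reference, a minimal sketch of the torchtune-style `instantiate` pattern the comment alludes to. The `_component_` path below is a placeholder, not a real forge module, so this illustrates the shape of the config rather than something that runs against this repo as-is:

```python
from omegaconf import OmegaConf
from torchtune import config

# The config entry names its target and passes extra kwargs (e.g. data_files)
# straight through, so new dataset args need no hardcoding in main.py.
cfg = OmegaConf.create(
    {
        "dataset": {
            "_component_": "forge.data.sft_dataset",  # hypothetical path
            "path": "json",
            "data_files": "train.json",
        }
    }
)
dataset = config.instantiate(cfg.dataset)
```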
I can support passing file paths. Which one (`data_files` or `path`) should it prioritize? For example, what if the user passes both `data_files` and `path`?
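As a hedged aside (not an answer to the precedence question): in Hugging Face `datasets` the two arguments are complementary rather than competing, with `path` selecting the builder and `data_files` the concrete files it loads:

```python
from datasets import load_dataset

# path picks the dataset builder ("json"); data_files picks the files.
ds = load_dataset(path="json", data_files="train.json", split="train")
```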
apps/sft/main.py (outdated)
@ebsmothers I tried to do this the way torchtune does it:

```python
current_num_tokens = (
    batch["labels"] != self._loss_fn.ignore_index
).sum()
```

but the current `self.loss_fn` doesn't have `ignore_index`.
Tbh I think this is something we kinda overparametrized in torchtune. You can just use the constant `CROSS_ENTROPY_IGNORE_IDX`.
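A minimal sketch of the suggested fix, assuming the constant carries PyTorch's default cross-entropy ignore index of -100 (its import location in this repo isn't shown in the thread):

```python
import torch

CROSS_ENTROPY_IGNORE_IDX = -100  # PyTorch's default ignore_index for CE loss

# Synthetic batch for illustration; in the trainer this comes from the dataloader.
batch = {"labels": torch.tensor([5, 7, CROSS_ENTROPY_IGNORE_IDX, 9])}
current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()  # -> 3
```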
Thanks! Fixed.
apps/sft/main.py (outdated)
nit: or a `BlockMask` when flex attention is enabled
Thanks! Fixed.
Just a few more comments; after that I think this should be good.
apps/sft/main.py (outdated)
Let's move this out of main.py. Personally I would put it in `src/forge/utils.py` for now; we can relocate later as needed.
Moved to `utils.py`.
apps/sft/main.py (outdated)
I think we can do without the conditional BlockMask logic for now. It has been in PyTorch stable for several releases now so we can just assume it's present (also I suspect we won't be running on non-CUDA devices for a bit either)
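A sketch of the simplification being asked for; the function and its signature are illustrative, not the PR's actual code:

```python
import torch
from torch.nn.attention.flex_attention import BlockMask  # unguarded import

def forward_step(
    model: torch.nn.Module, inputs: dict[str, torch.Tensor | BlockMask]
) -> torch.Tensor:
    """Inputs may carry a BlockMask when flex attention is enabled."""
    return model(**inputs)
```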
Updated, please check.
apps/sft/main.py (outdated)
Can we remove this?
Removed.
apps/sft/main.py (outdated)
This won't work with pipeline parallel, right? Since the backward happens inside of `step()` there, I think we will need to handle that case differently.
Updated the code to use `eval` mode when not doing the backward pass.
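A minimal sketch of that approach (names are assumptions, not the PR's code): validation runs in eval mode with gradients disabled, so no backward pass is triggered even under pipeline parallelism, where backward otherwise happens inside `step()`:

```python
import torch

def validation_step(model, loss_fn, batch: dict[str, torch.Tensor]) -> torch.Tensor:
    model.eval()
    with torch.no_grad():  # no graph is built, so no backward is needed
        logits = model(batch["tokens"])            # (batch, seq, vocab)
        loss = loss_fn(logits.flatten(0, 1), batch["labels"].flatten())
    model.train()  # restore training mode for the next train step
    return loss
```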
OK last comments I swear 😅
apps/sft/main.py (outdated)
Why do we need to do this?
Removed `self.model`.
apps/sft/main.py (outdated)
This is fine as a one-off, but let's add a TODO to do this in a nicer way. One of the things I don't like about torchtune is that we've littered the training script with similar `.get(value, default)` statements. Titan configs come with built-in default values; we can consider landing these fields directly in `ForgeJobConfig`.
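A hypothetical sketch of that direction (field names are illustrative, not the real `ForgeJobConfig` schema): declare the fields with defaults on the config itself, so the training script reads `cfg.validation.freq` instead of scattering `.get(value, default)` calls:

```python
from dataclasses import dataclass, field

@dataclass
class Validation:
    enabled: bool = False
    freq: int = 2  # run validation every `freq` training steps

@dataclass
class ForgeJobConfig:
    validation: Validation = field(default_factory=Validation)
```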
I have removed this.
We can probably remove all the instances of `infinite` now (not only in sft/main.py, but also in hf_dataset.py and sft_dataset.py). Other than that, looks good to me!
Added `dataset_val` and `validation` sections in the config file.

Test: the validation frequency (`freq`) is set to 2 (run validation every two steps). Tested with steps = 10 and steps = 200 (which exhausts the dataset).
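A hypothetical driver loop showing how `freq` gates validation (with `freq=2`, validation runs every two training steps, matching the test above):

```python
def run(trainer, dataloader, num_steps: int, freq: int = 2) -> None:
    for step in range(1, num_steps + 1):
        trainer.train_step(next(dataloader))
        if step % freq == 0:  # e.g. steps 2, 4, 6, ... when freq == 2
            trainer.validate()
```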