feat: add online data mixing plugin #612
Conversation
Thanks for making a pull request! 😃
tuning/data/data_processors.py (Outdated)
```diff
 def _process_dataset_configs(
-    self, dataset_configs: List[DataSetConfig]
+    self, dataset_configs: List[DataSetConfig], odm_config=None
 ) -> Union[Dataset, IterableDataset]:
```
Do you want to update the annotation to include DatasetDict or dict?
train_datasets_dict will be a dict
@romitjain May I know which annotation you are talking about? The return type of _process_dataset_configs? In our case it will be IterableDataset, isn't it?
For _process_dataset_configs, if odm_config is not None, then we return from _process_datasets_for_odm, which returns tuple[dict, dict].
I see, right :)
Fixed @romitjain, the types have turned out to be more complex than I thought.
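For reference, a minimal sketch of the broadened annotation, assuming the tuple-of-dicts return described above (the exact generics are illustrative, not the final code):

```python
from typing import Dict, List, Tuple, Union

from datasets import Dataset, IterableDataset


def _process_dataset_configs(
    self, dataset_configs: List["DataSetConfig"], odm_config=None
) -> Union[Dataset, IterableDataset, Tuple[Dict[str, Dataset], Dict[str, Dataset]]]:
    # when odm_config is set, return (train_datasets_dict, eval_datasets_dict)
    # as produced by _process_datasets_for_odm; otherwise return the usual
    # concatenated (possibly iterable) dataset
    ...
```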
NOTE: the format/lint error should go away once the fms_acceleration_odm package is available.
Force-pushed from f86e1e6 to 4729c6f (compare)
Signed-off-by: Mehant Kammakomati <[email protected]>
dushyantbehl left a comment
Requesting a minor revision of the documentation and restructuring of the code.
docs/advanced-data-preprocessing.md (Outdated)
```markdown
Note: If a user specifies data sampling they can expect the datasets to be mixed and individual samples in the dataset to not be broken unless the max_seq_len argument is smaller than the length of individual samples in the dataset

### Online Data Mixing
```
Can we make ODM a separate document so it's easy for users to find?
Done, I have made it into a new doc and changed references accordingly.
```yaml
- name: dataset_1
  split:
    train: 0.8
    validation: 0.2
```
Can you add a comment after the ratio noting that it will be overloaded as a reward dataset too in ODM?
If this is not explained in the documentation, we should do that too.
Comments are added, and the same is explained in the docs as well.
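To make the overloading concrete, here is a hypothetical data_config fragment mirroring the snippet above, parsed in Python (the dataset name and ratios are only illustrative):

```python
import yaml

# in ODM the validation split is overloaded as the reward dataset when
# reward_type is "validation_loss" (per the review discussion)
config = yaml.safe_load(
    """
datasets:
  - name: dataset_1
    split:
      train: 0.8
      validation: 0.2  # also used for ODM reward computation
"""
)
assert config["datasets"][0]["split"]["validation"] == 0.2
```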
```python
@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="odm"),
```
Are we enabling installation of the odm plugin by default in the tox.ini file? If not, I am assuming these tests are run separately.
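For context, a sketch of how such a gated test looks; the import path for is_fms_accelerate_available is an assumption based on the repo layout:

```python
import pytest

# assumed import path; the helper probes whether the optional
# fms-acceleration plugin (here the odm extra) is installed
from tuning.utils.import_utils import is_fms_accelerate_available


@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="odm"),
    reason="fms_acceleration_odm is not installed",
)
def test_online_data_mixing():
    # collected but skipped unless the odm plugin is installed, so these
    # tests can be run separately from the default environments
    ...
```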
```python
train_datasets_dict[d.name] = raw[train_split]
if eval_split in raw:
    eval_datasets_dict[d.name] = raw[eval_split]
return train_datasets_dict, eval_datasets_dict
```
Are we just looking to filter train and test splits from the dataset here?
Do you need the returned values to be a raw dict instead of a container?
> Are we just looking to filter train and test splits from the dataset here?

Yes.

> Do you need the returned values to be a raw dict instead of a container?

A dict would be good over DatasetDict, but it should work either way.
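Putting the quoted diff together, a minimal sketch of the split filtering being discussed (the split keys and the shape of processed_datasets are taken from the surrounding diffs; everything else is assumed):

```python
from typing import Dict, Tuple

from datasets import Dataset


def _process_datasets_for_odm(
    self, processed_datasets, train_split: str = "train", eval_split: str = "validation"
) -> Tuple[Dict[str, Dataset], Dict[str, Dataset]]:
    # map each dataset name to its train/eval split so the ODM dataloader
    # can mix them online; plain dicts are returned instead of a DatasetDict
    train_datasets_dict: Dict[str, Dataset] = {}
    eval_datasets_dict: Dict[str, Dataset] = {}
    for d, raw in processed_datasets:
        if train_split in raw:
            train_datasets_dict[d.name] = raw[train_split]
        if eval_split in raw:
            eval_datasets_dict[d.name] = raw[eval_split]
    return train_datasets_dict, eval_datasets_dict
```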
tuning/data/setup_dataprocessor.py (Outdated)
```python
if data_args.do_dataprocessing_only:
    if odm_config:
        raise ValueError(
            "data processing with online data mixing is not currently supported"
```
Any reason we do not support data-processing-only mode for the ODM config? We can just dump the processed datasets and return, right? ODM then has to be applied while training too, but is there a fundamental problem in compatibility?
> Any reason we do not support data-processing-only mode for the ODM config? We can just dump the processed datasets and return, right? ODM then has to be applied while training too, but is there a fundamental problem in compatibility?

True, I agree; I have removed this restriction.
```python
is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)

data_collator = None
if odm_config:
```
Is it possible to wrap all of the ODM stuff in a separate function and call it once here?
Done, wrapped them here in the function process_dataargs_odm.
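A rough sketch of what such a wrapper might bundle, based only on the pieces visible in this review (the import path and exact return shape are assumptions):

```python
# assumed import path for the plugin-availability check
from tuning.utils.import_utils import is_fms_accelerate_available


def process_dataargs_odm(train_args, odm_config, train_datasets_dict, eval_datasets_dict):
    """Bundle all ODM-specific setup behind a single call."""
    if not is_fms_accelerate_available(plugins="odm"):
        raise ValueError("odm_config requires the fms_acceleration_odm plugin")

    # the ODM dataloader prepares batches itself, so the trainer must skip
    # its own dataset preparation
    dataset_kwargs = {"skip_prepare_dataset": True}
    # ODM requires split_batches by design (see the later thread)
    train_args.accelerator_config = {"split_batches": True}
    return train_datasets_dict, eval_datasets_dict, dataset_kwargs, train_args
```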
tuning/data/setup_dataprocessor.py (Outdated)
```python
if odm_config:
    # Third Party
    # pylint: disable=import-outside-toplevel
    if not is_fms_accelerate_available(plugins="odm"):
```
Shouldn't we do this check at the top, as the first thing?
Can you mention where exactly?
tuning/data/setup_dataprocessor.py (Outdated)
```python
    reward_type=odm_config.odm.reward_type,
)
dataset_kwargs["skip_prepare_dataset"] = True
train_args.accelerator_config = {"split_batches": True}
```
Is split_batches needed for ODM?
Yes, by design it's needed.
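For intuition, a minimal illustration of what split_batches changes in Accelerate (not code from this PR): the dataloader yields the full global batch and Accelerate shards it across processes, so the ODM dataloader controls the composition of every global batch rather than only a per-device slice.

```python
from accelerate import Accelerator

# with split_batches=True, each batch produced by the dataloader is treated
# as the global batch and split across processes, instead of every process
# drawing an independent batch
accelerator = Accelerator(split_batches=True)
```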
```python
)

odm_config = None
if data_args.data_config_path:
```
Can we please do this inside the process_dataargs function? We want to keep sft_trainer clean from any data-related functionality.
I have thought about this, but we need to prepare the odm_config variable for the fms-acceleration plugin preparation step as well, which happens before the process_dataargs function. So it is hard to keep this piece of code within process_dataargs, which runs at a later point in time.
@kmehant can process_dataargs not initialize the ODM framework and then return anything if needed? As far as I can see, the only thing we return is a dataset of type ODM, so why can't we do the ODM framework initialization inside process_dataargs and return train and eval datasets as usual, just ODM-backed this time?
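A sketch of the ordering constraint being debated, with hypothetical helper names (load_data_config, prepare_acceleration_framework) standing in for the real calls:

```python
# hypothetical flow in sft_trainer.py; only the ordering matters here
odm_config = None
if data_args.data_config_path:
    data_config = load_data_config(data_args.data_config_path)  # hypothetical helper
    odm_config = getattr(data_config, "odm", None)

# the fms-acceleration plugin preparation consumes odm_config first ...
framework = prepare_acceleration_framework(train_args, odm_config)  # hypothetical

# ... and only afterwards does data processing run, which is why odm_config
# is prepared outside process_dataargs in the current patch
train_dataset, eval_dataset = process_dataargs(data_args, odm_config=odm_config)
```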
```python
is_padding_free=is_padding_free,
processor=processor,
is_multipack=is_multipack,
odm_config=odm_config,
```
You can keep the data config load inside process_dataargs and initialize ODM inside process_dataargs, possibly after this.
related - #612 (comment)
Signed-off-by: Mehant Kammakomati <[email protected]>
dushyantbehl left a comment
Thanks for fixing the earlier comments; requesting some more clarifications.
```yaml
- name: dataset_2
  split:
    train: 0.6
    validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
```
Suggested change:
```diff
-validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
+validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
```
done!
```yaml
- name: dataset_1
  split:
    train: 0.8
    validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
```
Suggested change:
```diff
-validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
+validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
```
done!
There were a couple more places, so I modified all of the files.
docs/tuning-techniques.md (Outdated)
```markdown
- `--multipack`: technique for *multi-gpu training* to balance out number of tokens processed in each device, to minimize waiting time.
- [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental):
  - `--fast_moe`: trains MoE models in parallel with [Scatter MoE kernels](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/accelerated-moe#fms-acceleration-for-mixture-of-experts), increasing throughput and decreasing memory usage.
- [odm_config](./tuning/config/acceleration_configs/odm.py) (experimental): See [advanced data preprocessing](./advanced-data-preprocessing.md#online-data-mixing) for usage with data_config. This plugin allows dynamically mixing datasets online during training adapting to training signals.
```
Do you want to add the PyTorch poster link here?
Done!
tuning/data/setup_dataprocessor.py (Outdated)
```diff
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
+# See the License for the specificm language governing permissions and
```
Can you please fix this typo?
specificm is the typo, isn't it? What do you want me to fix?
tuning/data/setup_dataprocessor.py (Outdated)
```python
)

# pylint: disable=import-outside-toplevel
if not is_fms_accelerate_available(plugins="odm"):
```
Can we move this to line 508, the top of this function?
done
```python
    processed_datasets.append((d, raw_datasets))

if odm_config:
    return self._process_datasets_for_odm(processed_datasets)
```
I missed this last time, but are sampling and concatenation of datasets not compatible with ODM? Can we let the normal data processing perform its function, i.e. give a processed dataset, and then wrap ODM on top? The way I see it, you don't need to modify code in data_processors.py: the function _process_datasets_for_odm gets called after the datasets are returned by _process_dataset_configs, and then you apply ODM processing by calling it inside setup_dataprocessor.py.
Sampling and concatenation should not be done with ODM, since that would be handled by the ODM dataloader.
tuning/data/setup_dataprocessor.py (Outdated)
```python
    raise RuntimeError(f"Failed to dump dataset due to error {e}") from e


def process_dataargs_odm(
```
nit: could we call this setup_train_dataset_for_odm, because you are not touching eval_dataset?
True, done!
tuning/data/setup_dataprocessor.py (Outdated)
```python
processor: AutoProcessor = None,
odm_config: ODMConfig = None,
train_dataset: Dict = None,
eval_dataset: Dict = None,
```
Can we rename this to reward_dataset and pass eval_dataset here, after adding a clarifying comment in the code?
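A small sketch of the requested rename with the clarifying comment inline (signature trimmed to the relevant arguments; the ODMConfig import path follows the docs diff above but is still an assumption):

```python
from typing import Dict

from tuning.config.acceleration_configs.odm import ODMConfig  # assumed module path


def setup_train_dataset_for_odm(
    odm_config: ODMConfig = None,
    train_dataset: Dict = None,
    # renamed from eval_dataset: in ODM the eval split doubles as the
    # reward dataset, so callers pass their eval_dataset here
    reward_dataset: Dict = None,
):
    ...
```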
```python
)

if data_args.data_config_path:
    train_dataset, eval_dataset, dataset_text_field = process_dataconfig_file(
```
We can potentially return odm_config from this function, or you can map the original data_config returned from this function to odm_config inside it here. We can initialize the Acceleration framework ODM here too, I think; this would save us loading the data config twice.
Signed-off-by: Mehant Kammakomati <[email protected]>
@dushyantbehl All the above comments are addressed and pushed. Please take a pass!
dushyantbehl left a comment
LGTM



Details provided in foundation-model-stack/fms-acceleration#152