Conversation

@kmehant (Collaborator) commented Sep 24, 2025

@github-actions (bot):

Thanks for making a pull request! 😃
One of the maintainers will review and advise on the next steps.

@kmehant changed the title from "[DO NOT MERGE] feat: add online data mixing plugin" to "feat: [DO NOT MERGE] add online data mixing plugin" Sep 24, 2025
@github-actions bot added the feat label Sep 24, 2025
def _process_dataset_configs(
-    self, dataset_configs: List[DataSetConfig]
+    self, dataset_configs: List[DataSetConfig], odm_config=None
) -> Union[Dataset, IterableDataset]:
Collaborator:

Do you want to update the annotation to include DatasetDict or dict?

train_datasets_dict will be a dict

Collaborator Author:

@romitjain May I know which annotation you are talking about? The return type of _process_dataset_configs? In our case it will be IterableDataset, isn't it?

Collaborator:

For _process_dataset_configs, if odm_config is not None, then we return from _process_datasets_for_odm, which returns tuple[dict, dict].

Collaborator Author:

I see, right :)

Collaborator Author:

Fixed @romitjain, the types have turned out to be more complex than I thought.
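For reference, a minimal sketch of what the widened annotation could look like, based only on the types mentioned in this thread (the ProcessedData alias is illustrative, not the PR's final code):

```python
from typing import Dict, Tuple, Union

from datasets import Dataset, IterableDataset

# When odm_config is set, _process_dataset_configs returns per-dataset
# (train, eval) dicts keyed by dataset name instead of one merged dataset.
ProcessedData = Union[
    Dataset,
    IterableDataset,
    Tuple[Dict[str, Dataset], Dict[str, Dataset]],
]
```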

@kmehant (Collaborator Author) commented Sep 25, 2025

NOTE: the format/lint error should go away once the fms_acceleration_odm package is available.

@kmehant force-pushed the odm-plugin branch 3 times, most recently from f86e1e6 to 4729c6f, September 29, 2025 11:33
@kmehant changed the title from "feat: [DO NOT MERGE] add online data mixing plugin" to "feat: add online data mixing plugin" Sep 29, 2025
ashokponkumar previously approved these changes Sep 29, 2025
kmehant added 14 commits October 6, 2025 22:33
@dushyantbehl (Collaborator) left a comment:

Requesting minor revision of the documentation and restructuring of the code.


Note: If a user specifies data sampling, they can expect the datasets to be mixed and individual samples in the dataset not to be broken, unless the max_seq_len argument is smaller than the length of individual samples in the dataset.

### Online Data Mixing
Collaborator:

Can we make ODM a separate document so it's easy for users to find?

Collaborator Author:

Done, I have made it into a new doc and changed references accordingly.

- name: dataset_1
  split:
    train: 0.8
    validation: 0.2
Collaborator:

Can you add a comment after the ratio that this split will be overloaded as a reward dataset too in ODM?

Collaborator:

If this is not explained in the documentation, we should do that too.

Collaborator Author:

Comments are added, and the same is explained in the docs as well.



@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="odm"),
Collaborator:

Are we enabling installing the ODM plugin by default in the tox.ini file? If not, I am assuming these tests are run separately.

Collaborator Author:

No, we don't want to install the ODM plugin by default. Yes, I have run these tests on my end; see the screenshots below:

[Screenshots: local ODM test runs passing, Oct 7, 2025]
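For context, a test gated this way might look like the following sketch; the import path, reason string, and test body are illustrative, not the PR's actual code:

```python
import pytest

# assumption: is_fms_accelerate_available is the repo's existing helper;
# this import path is a guess for illustration only
from tuning.config.acceleration_configs import is_fms_accelerate_available

@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="odm"),
    reason="fms-acceleration ODM plugin is not installed",
)
def test_odm_code_path():
    # runs only in environments where the optional plugin was installed,
    # e.g. via a dedicated tox environment rather than the default one
    ...
```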

    train_datasets_dict[d.name] = raw[train_split]
    if eval_split in raw:
        eval_datasets_dict[d.name] = raw[eval_split]
return train_datasets_dict, eval_datasets_dict
Collaborator:

Are we just looking to filter train and test splits from the dataset here?
Do you need the returned values to be a raw dict instead of a container?

Collaborator Author:

> Are we just looking to filter train and test splits from the dataset here?

Yes.

> Do you need the returned values to be a raw dict instead of a container?

A dict would be good over DatasetDict, but it should work either way.
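A minimal sketch of the filtering described in this thread, assuming processed_datasets pairs each DataSetConfig with its dict of splits (the helper name and default split names are illustrative):

```python
from typing import Dict, List, Tuple

from datasets import Dataset

def split_train_and_eval(
    processed_datasets: List[Tuple["DataSetConfig", Dict[str, Dataset]]],
    train_split: str = "train",
    eval_split: str = "validation",
) -> Tuple[Dict[str, Dataset], Dict[str, Dataset]]:
    """Collect each dataset's train/eval splits into plain dicts keyed by name."""
    train_datasets_dict: Dict[str, Dataset] = {}
    eval_datasets_dict: Dict[str, Dataset] = {}
    for d, raw in processed_datasets:
        train_datasets_dict[d.name] = raw[train_split]
        if eval_split in raw:  # the eval split is optional per dataset
            eval_datasets_dict[d.name] = raw[eval_split]
    return train_datasets_dict, eval_datasets_dict
```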

if data_args.do_dataprocessing_only:
    if odm_config:
        raise ValueError(
            "data processing with online data mixing is not currently supported"
Collaborator:

Any reason we do not support data-processing-only mode for the ODM config?

We can just dump the processed datasets and return, right?

ODM then has to be applied while training too, but is there a fundamental problem in compatibility?

Collaborator Author:

> Any reason we do not support data-processing-only mode for the ODM config?
>
> We can just dump the processed datasets and return, right?
>
> ODM then has to be applied while training too, but is there a fundamental problem in compatibility?

True, I agree; I have removed this restriction.

is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)

data_collator = None
if odm_config:
Collaborator:

Is it possible to wrap all of the ODM stuff in a separate function and call it once inside?

Collaborator Author:

Done, wrapped them here: func process_dataargs_odm.
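As a rough illustration of what such a wrapper gathers in one place; only reward_type, skip_prepare_dataset, and split_batches come from the snippets in this PR, while OnlineMixingDataset and the exact signature are assumptions:

```python
def process_dataargs_odm(train_args, odm_config, train_dataset, eval_dataset, dataset_kwargs):
    """Sketch: collect all ODM-specific setup behind one call."""
    # Third Party
    # pylint: disable=import-outside-toplevel
    from fms_acceleration_odm import OnlineMixingDataset  # assumed plugin API

    # the eval splits double as reward datasets when reward_type is validation_loss
    mixed_train_dataset = OnlineMixingDataset(
        train_dataset,
        eval_dataset,
        reward_type=odm_config.odm.reward_type,
    )
    dataset_kwargs["skip_prepare_dataset"] = True  # ODM prepares samples itself
    train_args.accelerator_config = {"split_batches": True}  # required by ODM's design
    return mixed_train_dataset, dataset_kwargs
```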

if odm_config:
    # Third Party
    # pylint: disable=import-outside-toplevel
    if not is_fms_accelerate_available(plugins="odm"):
Collaborator:

Shouldn't we do this check at the top, as the first thing?

Collaborator Author:

Can you mention where exactly?

    reward_type=odm_config.odm.reward_type,
)
dataset_kwargs["skip_prepare_dataset"] = True
train_args.accelerator_config = {"split_batches": True}
Collaborator:

Is split_batches needed for ODM?

Collaborator Author:

Yes, it's needed by design.
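For readers unfamiliar with the flag, a short sketch of what split_batches changes in plain Accelerate; the rationale in the comment is my reading, not something the PR states:

```python
from accelerate import Accelerator

# With split_batches=True, the dataloader yields the *global* batch and
# Accelerate slices it across processes, instead of each rank drawing its
# own batch. A single ODM mixer emitting one global batch per step fits
# this mode (assumed rationale).
accelerator = Accelerator(split_batches=True)
```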

)

odm_config = None
if data_args.data_config_path:
Collaborator:

Can we please do this inside the process_dataargs function?

We want to keep sft_trainer clean of any data-related functionality.

Collaborator Author:

I have thought about this, but we need to prepare the odm_config variable for the fms-acceleration plugin preparation step as well, which happens before the process_dataargs function. So it's hard to keep this piece of code within process_dataargs, which happens at a later point in time.

Collaborator:

@kmehant can process_dataargs not initialize the ODM framework and then return anything if needed? As far as I see, the only thing we return is a dataset of type ODM, so why can't we do the ODM framework initialization inside process_dataargs and return train and eval datasets as usual, just ODM this time?
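To make the ordering constraint kmehant describes concrete, a sketch of the flow in sft_trainer; the helper names are stand-ins invented for illustration:

```python
# Hypothetical stand-ins so the sketch is self-contained:
def load_data_config(path): ...
def build_odm_config(data_config): ...
def prepare_acceleration_framework(model_args, odm_config=None): ...
def process_dataargs(data_args, odm_config=None): ...

def train(model_args, data_args):
    # 1. odm_config must exist before the acceleration framework is prepared
    odm_config = None
    if data_args.data_config_path:
        data_config = load_data_config(data_args.data_config_path)
        if getattr(data_config, "odm", None) is not None:
            odm_config = build_odm_config(data_config)

    # 2. plugin preparation consumes odm_config ...
    framework = prepare_acceleration_framework(model_args, odm_config=odm_config)

    # 3. ... and only afterwards does data processing run, which is why
    #    creating odm_config inside process_dataargs would be too late
    return process_dataargs(data_args, odm_config=odm_config)
```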

    is_padding_free=is_padding_free,
    processor=processor,
    is_multipack=is_multipack,
    odm_config=odm_config,
Collaborator:

You can keep the data config load inside process_dataargs and initialize ODM inside process_dataargs, possibly after this.

Collaborator Author:

related - #612 (comment)

@dushyantbehl (Collaborator) left a comment:

Thanks for fixing the earlier comments; requesting some more clarifications.

- name: dataset_2
  split:
    train: 0.6
    validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
Collaborator:

Suggested change:
-    validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
+    validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.

Collaborator Author:

done!

- name: dataset_1
  split:
    train: 0.8
    validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
Collaborator:

Suggested change:
-    validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
+    validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.

Collaborator Author:

done!

Collaborator Author:

There were a couple more places, so I modified them in all the files.

- `--multipack`: technique for *multi-gpu training* to balance out the number of tokens processed on each device, to minimize waiting time.
- [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental):
  - `--fast_moe`: trains MoE models in parallel with [Scatter MoE kernels](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/accelerated-moe#fms-acceleration-for-mixture-of-experts), increasing throughput and decreasing memory usage.
- [odm_config](./tuning/config/acceleration_configs/odm.py) (experimental): See [advanced data preprocessing](./advanced-data-preprocessing.md#online-data-mixing) for usage with data_config. This plugin allows dynamically mixing datasets online during training, adapting to training signals.
Collaborator:

Do you want to link the PyTorch poster here?

Collaborator Author:

Done!

 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
+# See the License for the specificm language governing permissions and
Collaborator:

Can you please fix this typo?

Collaborator Author:

specificm, this is the typo, isn't it? What do you want me to fix?

)

# pylint: disable=import-outside-toplevel
if not is_fms_accelerate_available(plugins="odm"):
Collaborator:

Can we move this to line 508, the top of this function?

Collaborator Author:

done

    processed_datasets.append((d, raw_datasets))

if odm_config:
    return self._process_datasets_for_odm(processed_datasets)
Collaborator:

I missed this last time, but are sampling and concatenation of datasets not compatible with ODM?

Can we let the normal data processing perform its function, i.e. give a processed dataset, and then wrap ODM on top?

The way I see it, you don't need to modify code in data_processors.py; the function _process_datasets_for_odm gets called after the datasets are returned by _process_dataset_configs, and then you apply ODM processing by calling this function inside setup_dataprocessor.py.

Collaborator Author:

Sampling and concatenation should not be done with ODM, since that is handled by the ODM dataloader.

raise RuntimeError(f"Failed to dump dataset due to error {e}") from e


def process_dataargs_odm(
Collaborator:

nit: could we call this setup_train_dataset_for_odm, since you are not touching eval_dataset?

Collaborator Author:

True, done!

processor: AutoProcessor = None,
odm_config: ODMConfig = None,
train_dataset: Dict = None,
eval_dataset: Dict = None,
Collaborator:

Can we rename this to reward_dataset and pass eval_dataset here, after adding a comment in the code?
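A sketch of the suggested rename; apart from the parameter names visible in the snippet above, the signature is illustrative:

```python
from typing import Dict

def setup_train_dataset_for_odm(
    train_args,
    odm_config,
    train_dataset: Dict = None,
    # the caller passes eval_dataset here; inside ODM it is consumed as the
    # reward dataset (e.g. when reward_type is validation_loss)
    reward_dataset: Dict = None,
):
    ...
```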

)

if data_args.data_config_path:
    train_dataset, eval_dataset, dataset_text_field = process_dataconfig_file(
Collaborator:

We can potentially return odm_config from this function... or you can return the original data_config from this function and build odm_config from it here.

We can initialize the Acceleration framework ODM here too, I think; this would save us loading the data config twice.
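A sketch of that suggestion, with hypothetical stand-ins so it runs in isolation; the extra return value is the point being made:

```python
# Hypothetical stand-ins for illustration only:
def load_data_config(path): ...
def _process(data_config): return None, None, None

def process_dataconfig_file(data_args):
    """Sketch: parse the data config once and surface it to the caller."""
    data_config = load_data_config(data_args.data_config_path)
    train_dataset, eval_dataset, dataset_text_field = _process(data_config)
    # returning data_config lets the caller build odm_config (and initialize
    # the ODM acceleration framework) without reading the file a second time
    return train_dataset, eval_dataset, dataset_text_field, data_config
```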

@kmehant (Collaborator Author) commented Oct 7, 2025

@dushyantbehl All the above comments are addressed and pushed. Please take a pass!

@kmehant (Collaborator Author) commented Oct 7, 2025

[Screenshot: test run passing, Oct 7, 2025]

Tests on the latest changes pass.

@dushyantbehl (Collaborator) left a comment:

LGTM

@dushyantbehl merged commit 4ec1340 into foundation-model-stack:main Oct 8, 2025
9 checks passed