Merged
41 commits
c9a519e
feat: add odm plugin user facing config
kmehant Sep 22, 2025
d63eff0
feat: add odm plugin
kmehant Sep 22, 2025
6f95e56
feat: add odm plugin
kmehant Sep 22, 2025
8370997
feat: add odm plugin
kmehant Sep 22, 2025
88f68b1
feat: add odm plugin
kmehant Sep 22, 2025
a1b485b
feat: add odm plugin
kmehant Sep 22, 2025
be3afbd
feat: add odm plugin
kmehant Sep 22, 2025
6b0d4cd
feat: add odm plugin
kmehant Sep 23, 2025
ad35a69
feat: add odm plugin
kmehant Sep 23, 2025
245e2a2
feat: add odm plugin
kmehant Sep 23, 2025
b5cd791
feat: add odm_plugin
kmehant Sep 24, 2025
e8d748b
feat: add odm_plugin
kmehant Sep 24, 2025
3ebb844
feat: add odm_plugin
kmehant Sep 24, 2025
5953bfb
feat: add odm_plugin
kmehant Sep 24, 2025
9a8804a
feat: add odm_plugin
kmehant Sep 24, 2025
89b95a8
feat: add odm_plugin
kmehant Sep 24, 2025
bb5bc0d
feat: add odm_plugin
kmehant Sep 24, 2025
32f2982
feat: add odm_plugin
kmehant Sep 24, 2025
5756ad5
feat: add odm_plugin
kmehant Sep 24, 2025
8b02967
feat: add odm_plugin
kmehant Sep 24, 2025
2086d1e
feat: add odm_plugin
kmehant Sep 24, 2025
dadb6f0
feat: add odm_plugin
kmehant Sep 24, 2025
6ff740b
feat: add odm_plugin
kmehant Sep 24, 2025
beda5df
feat: add odm_plugin
kmehant Sep 24, 2025
304de74
fix: code refactor
kmehant Sep 25, 2025
9d69c85
fix: code refactor
kmehant Sep 25, 2025
a35ca56
fix: code refactor
kmehant Sep 25, 2025
0865130
fix: code refactor
kmehant Sep 25, 2025
c6f333e
fix: code refactor
kmehant Sep 25, 2025
c9945d1
fix: code refactor
kmehant Sep 25, 2025
12ed5fe
fix: code refactor
kmehant Sep 25, 2025
9fc8238
fix: function and argument types
kmehant Sep 29, 2025
0bff8b4
fix: lint errors
kmehant Sep 29, 2025
dc03331
fix: unit tests
kmehant Sep 29, 2025
d5db867
docs: add docs
kmehant Sep 29, 2025
e206f23
fix: refactor
kmehant Oct 7, 2025
81d5917
fix: refactor
kmehant Oct 7, 2025
ae7698e
fix: refactor
kmehant Oct 7, 2025
b72e421
feat: resume functionality
kmehant Oct 7, 2025
86eb54a
feat: resume functionality
kmehant Oct 7, 2025
5dd066f
feat: resume functionality
kmehant Oct 7, 2025
2 changes: 2 additions & 0 deletions build/Dockerfile
@@ -165,12 +165,14 @@ RUN --mount=type=cache,target=/home/${USER}/.cache/pip,uid=${USER_UID} \
# fms_acceleration_foak = Fused LoRA and triton kernels
# fms_acceleration_aadp = Padding-Free Flash Attention Computation
# fms_acceleration_moe = Parallelized Mixture of Experts
# fms_acceleration_odm = Online Data Mixing
RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[fms-accel]"; \
python -m fms_acceleration.cli install fms_acceleration_peft; \
python -m fms_acceleration.cli install fms_acceleration_foak; \
python -m fms_acceleration.cli install fms_acceleration_aadp; \
python -m fms_acceleration.cli install fms_acceleration_moe; \
python -m fms_acceleration.cli install fms_acceleration_odm; \
fi

RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
3 changes: 2 additions & 1 deletion build/nvcr.Dockerfile
@@ -57,7 +57,8 @@ RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
python -m fms_acceleration.cli install fms_acceleration_peft && \
python -m fms_acceleration.cli install fms_acceleration_foak && \
python -m fms_acceleration.cli install fms_acceleration_aadp && \
python -m fms_acceleration.cli install fms_acceleration_moe; \
python -m fms_acceleration.cli install fms_acceleration_moe && \
python -m fms_acceleration.cli install fms_acceleration_odm; \
fi

RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \
68 changes: 66 additions & 2 deletions docs/advanced-data-preprocessing.md
@@ -4,6 +4,7 @@ Our library also supports a powerful data processing backend which can be used b
1. Creating custom data processing pipeline for the datasets.
1. Combining multiple datasets into one, even if they have different formats.
1. Mixing datasets as required and sampling each dataset with different weights.
1. Dynamically mixing datasets online during training based on training signals, via the `fms_acceleration_odm` plugin.

These things are supported via what we call a [`data_config`](#data-config) which can be passed as an argument to sft trainer.

@@ -34,7 +35,7 @@ process the datasets. Users can currently pass both YAML or JSON based configura
The data config schema is designed to define datasets and their processing strategies in a structured way.

It consists of the following top-level keys:
- `datapreprocessor`: Defines global data processing parameters, such as the type (`default`), sampling stopping strategy (`all_exhausted` or `first_exhausted`), and sampling seed for reproducibility.
- `datapreprocessor`: Defines global data processing parameters, such as the type (`default` or `odm`), sampling stopping strategy (`all_exhausted` or `first_exhausted`), and sampling seed for reproducibility.
- `datasets`: A list of dataset configurations, each describing the dataset name, paths, optional builders, sampling ratios, and data handlers.

At the top level, the data config schema looks like this:
@@ -129,11 +130,21 @@ definitions:
Users can create a data config file in any of YAML or JSON format they choose (we provide examples of YAML for ease of use). The file should follow the schema outlined above with the following parameters:

`datapreprocessor`:
- `type` (optional, str): Type of data preprocessor, `default` is currently the only supported type.
- `type` (optional, str): Type of data preprocessor; `default` and `odm` are the two supported types. Use of `odm` requires [installation](./tuning-techniques.md#fms-acceleration) of the `fms_acceleration_odm` package.
- `streaming` (optional, bool): Stream datasets using [IterableDatasets](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.IterableDataset).
- `sampling_stopping_strategy` (optional, str): Dataset interleave stopping strategy in case of choosing to mix multiple datasets by weight, supported values are [`all_exhausted` or `first_exhausted`](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy), defaults to `all_exhausted`.
- `seed` (optional, int): [seed](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.seed) to use for interleaving datasets, for reproducibility choose same value, defaults to 42.
- `chat_template` (optional, str): pass `chat_template` via data_config for multi-turn data, replaces existing default chat template.
- `odm` (optional): required when `type` is `odm`; provides the configuration for online data mixing.

The `odm` config is required when the `datapreprocessor` `type` is `odm` and has the following fields:

`odm`:
- `update_interval` (optional, int, defaults to `1`): A Multi-Armed Bandit (MAB) learns from training signals and provides mixing probabilities across datasets. `update_interval` defines how often the MAB is updated with training signals, counted in steps.
- `sampling_interval` (optional, int, defaults to `1`): Defines how often the MAB chooses a dataset to sample from, counted in samples.
- `reward_type` (optional, str, defaults to `entropy`): Type of reward used to update the MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, and `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards).
- `gamma` (optional, float, defaults to `0.1`): MAB hyper-parameter similar to an exploration factor.
- `eta` (optional, float, defaults to `0.1`): MAB hyper-parameter similar to a learning rate.
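To give intuition for `gamma` and `eta`, here is an EXP3-style bandit sketch. This is an illustration only, not the plugin's actual implementation: it mixes a weight-based distribution with uniform exploration (`gamma`) and scales multiplicative weight updates by a learning rate (`eta`).

```python
import math
import random


class Exp3:
    """Minimal EXP3-style multi-armed bandit sketch: one arm per dataset."""

    def __init__(self, n_arms, gamma=0.1, eta=0.1, seed=42):
        self.n = n_arms
        self.gamma = gamma            # exploration factor
        self.eta = eta                # learning rate
        self.weights = [1.0] * n_arms
        self.rng = random.Random(seed)

    def probabilities(self):
        total = sum(self.weights)
        # mix the weight-based distribution with uniform exploration
        return [(1 - self.gamma) * w / total + self.gamma / self.n
                for w in self.weights]

    def sample(self):
        probs = self.probabilities()
        return self.rng.choices(range(self.n), weights=probs, k=1)[0]

    def update(self, arm, reward):
        # importance-weighted reward estimate for the chosen arm only
        probs = self.probabilities()
        estimate = reward / probs[arm]
        self.weights[arm] *= math.exp(self.eta * estimate / self.n)


mab = Exp3(n_arms=2, gamma=0.1, eta=0.1)
arm = mab.sample()        # pick a dataset index
mab.update(arm, reward=1.0)
```

A higher `gamma` keeps sampling closer to uniform (more exploration), while a higher `eta` makes the mixing probabilities react faster to rewards.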

`datasets` (list):
- `name` (optional, str): A unique identifier for the dataset.
@@ -192,6 +203,59 @@ We also allow users to pass a [`seed`](https://huggingface.co/docs/datasets/v3.2

Note: If a user specifies data sampling, they can expect the datasets to be mixed and individual samples in the dataset not to be broken, unless the `max_seq_len` argument is smaller than the length of individual samples in the dataset.
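To make the two stopping strategies concrete, here is a toy, stdlib-only re-implementation of weighted interleaving. The real behavior is provided by Hugging Face `datasets.interleave_datasets`; this sketch only illustrates the semantics of `first_exhausted` versus `all_exhausted`.

```python
import random


def interleave(datasets, probabilities, seed=42, stopping_strategy="all_exhausted"):
    """Toy weighted interleaving over plain Python sequences.

    first_exhausted: stop as soon as any dataset runs out.
    all_exhausted: restart smaller datasets (oversampling them) until
    every dataset has been fully seen at least once.
    """
    rng = random.Random(seed)
    iters = [iter(d) for d in datasets]
    seen_fully = [False] * len(datasets)
    out = []
    while True:
        i = rng.choices(range(len(datasets)), weights=probabilities)[0]
        try:
            out.append(next(iters[i]))
        except StopIteration:
            seen_fully[i] = True
            if stopping_strategy == "first_exhausted" or all(seen_fully):
                return out
            iters[i] = iter(datasets[i])  # restart the exhausted dataset
            out.append(next(iters[i]))
```

With `all_exhausted`, every sample of every dataset appears in the output at least once; with `first_exhausted`, mixing stops the moment the first dataset is consumed.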

### Online Data Mixing
Dataset mixing can also be dynamic, adapting online during training based on training signals. We provide this feature through the `fms_acceleration_odm` plugin; more details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing).

#### How to Use

The `dataprocessor` `type` has to be set to `odm`, and the `odm` config should be provided in the `odm` section of the data config file. An example is shown below:

```yaml
dataprocessor:
  type: odm
  odm:
    update_interval: 1 # update every step
    sampling_interval: 1 # sample category for every sample
    reward_type: validation_loss # uses eval loss of each dataset as reward
    gamma: 0.1 # MAB hyper-parameter
    eta: 0.2 # MAB hyper-parameter
```

Here `update_interval` is set to `1`, which updates the MAB on every step. `sampling_interval` is set to `1`, which chooses a dataset to sample from for every sample. `reward_type` is set to `validation_loss`, which uses the validation loss across datasets as the training signal to reward MAB decisions during training. An example `datasets` section can look like below:

```yaml
datasets:
  - name: dataset_1
    split:
      train: 0.8
      validation: 0.2
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
  - name: dataset_2
    split:
      train: 0.9
      validation: 0.1
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
```
As you can notice, `validation` under `split` is provided for each dataset. It is required here because the `reward_type` is `validation_loss`, which needs a validation dataset to be available. The same applies to the `entropy`, `entropy3_varent1`, and `entropy_last_token` rewards, while the `train_loss` and `gradnorm` rewards do not require a validation split.
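The interplay of `sampling_interval` (counted in samples) and `update_interval` (counted in steps) can be sketched with a toy training loop. `ToyBandit`, `train_with_odm`, and the reward wiring below are illustrative stand-ins, not the plugin's API.

```python
import itertools
import math
import random


class ToyBandit:
    """Stand-in for the plugin's MAB: multiplicative weight updates only."""

    def __init__(self, n, eta=0.2, seed=42):
        self.w = [1.0] * n
        self.eta = eta
        self.rng = random.Random(seed)

    def sample(self):
        return self.rng.choices(range(len(self.w)), weights=self.w)[0]

    def update(self, arm, reward):
        self.w[arm] *= math.exp(self.eta * reward)


def train_with_odm(bandit, datasets, val_loss_fn, train_step_fn,
                   num_steps, batch_size, update_interval=1, sampling_interval=1):
    # cycle() keeps the toy datasets from running out; the real plugin handles epochs
    iters = {name: itertools.cycle(ds) for name, ds in datasets.items()}
    names = list(datasets)
    current = names[bandit.sample()]
    samples_drawn = 0
    for step in range(1, num_steps + 1):
        batch = []
        for _ in range(batch_size):
            if samples_drawn % sampling_interval == 0:
                current = names[bandit.sample()]  # re-pick a source every `sampling_interval` samples
            batch.append(next(iters[current]))
            samples_drawn += 1
        train_step_fn(batch)
        if step % update_interval == 0:           # reward the MAB every `update_interval` steps
            for i, name in enumerate(names):
                bandit.update(i, -val_loss_fn(name))  # lower validation loss -> higher reward
```

Under this sketch, datasets whose validation loss stays high are sampled less often over time, which is the adaptive behavior the `validation_loss` reward aims for.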

### Dataset Splitting

In addition to [sampling and mixing](#data-mixing), our library supports **dataset splitting**, which allows users to split a dataset into training and validation splits using the `split` field in the dataset config.
1 change: 1 addition & 0 deletions docs/tuning-techniques.md
@@ -470,6 +470,7 @@ The list of configurations for various `fms_acceleration` plugins:
- `--multipack`: technique for *multi-gpu training* to balance out number of tokens processed in each device, to minimize waiting time.
- [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental):
- `--fast_moe`: trains MoE models in parallel with [Scatter MoE kernels](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/accelerated-moe#fms-acceleration-for-mixture-of-experts), increasing throughput and decreasing memory usage.
- [odm_config](./tuning/config/acceleration_configs/odm.py) (experimental): See [advanced data preprocessing](./advanced-data-preprocessing.md#online-data-mixing) for usage with data_config. This plugin dynamically mixes datasets online during training, adapting to training signals.

Notes:
* `quantized_lora_config` requires that it be used along with LoRA tuning technique. See [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass.
3 changes: 3 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
@@ -34,6 +34,9 @@
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split.yaml"
)
DATA_CONFIG_MULTIPLE_DATASETS_ODM_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_odm.yaml"
)
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML_2 = os.path.join(
PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split_2.yaml"
)
@@ -0,0 +1,70 @@
dataprocessor:
  type: odm
  sampling_stopping_strategy: first_exhausted # ignored
  seed: 66
  odm:
    update_interval: 1 # update every step
    sampling_interval: 1 # sample category for every sample
    reward_type: validation_loss # uses eval loss of each dataset as reward
    gamma: 0.1 # MAB hyper-parameter
    eta: 0.2 # MAB hyper-parameter
datasets:
  - name: dataset_1
    split:
      train: 0.8
      validation: 0.2
    sampling: 0.3 # ignored
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
  - name: dataset_2
    split:
      train: 0.6
      validation: 0.2
    sampling: 0.4 # ignored
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
  - name: dataset_3
    split:
      train: 0.4
      validation: 0.1
    sampling: 0.3 # ignored
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
  - name: dataset_4
    split:
      train: 0.0
      validation: 0.3 # ignored
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
3 changes: 3 additions & 0 deletions tests/artifacts/testdata/__init__.py
@@ -44,6 +44,9 @@
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
)
NESTFUL_DATA_INPUT_OUTPUT_JSONL = os.path.join(
JSONL_DATA_DIR, "nestful_100_samples_input_output.jsonl"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
)