Merged
41 commits
c9a519e
feat: add odm plugin user facing config
kmehant Sep 22, 2025
d63eff0
feat: add odm plugin
kmehant Sep 22, 2025
6f95e56
feat: add odm plugin
kmehant Sep 22, 2025
8370997
feat: add odm plugin
kmehant Sep 22, 2025
88f68b1
feat: add odm plugin
kmehant Sep 22, 2025
a1b485b
feat: add odm plugin
kmehant Sep 22, 2025
be3afbd
feat: add odm plugin
kmehant Sep 22, 2025
6b0d4cd
feat: add odm plugin
kmehant Sep 23, 2025
ad35a69
feat: add odm plugin
kmehant Sep 23, 2025
245e2a2
feat: add odm plugin
kmehant Sep 23, 2025
b5cd791
feat: add odm_plugin
kmehant Sep 24, 2025
e8d748b
feat: add odm_plugin
kmehant Sep 24, 2025
3ebb844
feat: add odm_plugin
kmehant Sep 24, 2025
5953bfb
feat: add odm_plugin
kmehant Sep 24, 2025
9a8804a
feat: add odm_plugin
kmehant Sep 24, 2025
89b95a8
feat: add odm_plugin
kmehant Sep 24, 2025
bb5bc0d
feat: add odm_plugin
kmehant Sep 24, 2025
32f2982
feat: add odm_plugin
kmehant Sep 24, 2025
5756ad5
feat: add odm_plugin
kmehant Sep 24, 2025
8b02967
feat: add odm_plugin
kmehant Sep 24, 2025
2086d1e
feat: add odm_plugin
kmehant Sep 24, 2025
dadb6f0
feat: add odm_plugin
kmehant Sep 24, 2025
6ff740b
feat: add odm_plugin
kmehant Sep 24, 2025
beda5df
feat: add odm_plugin
kmehant Sep 24, 2025
304de74
fix: code refactor
kmehant Sep 25, 2025
9d69c85
fix: code refactor
kmehant Sep 25, 2025
a35ca56
fix: code refactor
kmehant Sep 25, 2025
0865130
fix: code refactor
kmehant Sep 25, 2025
c6f333e
fix: code refactor
kmehant Sep 25, 2025
c9945d1
fix: code refactor
kmehant Sep 25, 2025
12ed5fe
fix: code refactor
kmehant Sep 25, 2025
9fc8238
fix: function and argument types
kmehant Sep 29, 2025
0bff8b4
fix: lint errors
kmehant Sep 29, 2025
dc03331
fix: unit tests
kmehant Sep 29, 2025
d5db867
docs: add docs
kmehant Sep 29, 2025
e206f23
fix: refactor
kmehant Oct 7, 2025
81d5917
fix: refactor
kmehant Oct 7, 2025
ae7698e
fix: refactor
kmehant Oct 7, 2025
b72e421
feat: resume functionality
kmehant Oct 7, 2025
86eb54a
feat: resume functionality
kmehant Oct 7, 2025
5dd066f
feat: resume functionality
kmehant Oct 7, 2025
1 change: 1 addition & 0 deletions README.md
@@ -8,6 +8,7 @@
- [Advanced Data Processing](./docs/advanced-data-preprocessing.md#data-config)
- [Guidelines on supported data formats](./docs/advanced-data-preprocessing.md#use-cases-supported-via-command-line-argument-training_data_path)
- [Offline data processing](#offline-data-preprocessing)
- [Online data mixing](./docs/online-data-mixing.md)
- [Additional Frameworks](#additional-frameworks)
- [Inference](#inference)
- [Validation](#validation)
2 changes: 2 additions & 0 deletions build/Dockerfile
@@ -165,12 +165,14 @@ RUN --mount=type=cache,target=/home/${USER}/.cache/pip,uid=${USER_UID} \
# fms_acceleration_foak = Fused LoRA and triton kernels
# fms_acceleration_aadp = Padding-Free Flash Attention Computation
# fms_acceleration_moe = Parallelized Mixture of Experts
# fms_acceleration_odm = Online Data Mixing
RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[fms-accel]"; \
python -m fms_acceleration.cli install fms_acceleration_peft; \
python -m fms_acceleration.cli install fms_acceleration_foak; \
python -m fms_acceleration.cli install fms_acceleration_aadp; \
python -m fms_acceleration.cli install fms_acceleration_moe; \
python -m fms_acceleration.cli install fms_acceleration_odm; \
fi

RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
3 changes: 2 additions & 1 deletion build/nvcr.Dockerfile
@@ -57,7 +57,8 @@ RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
python -m fms_acceleration.cli install fms_acceleration_peft && \
python -m fms_acceleration.cli install fms_acceleration_foak && \
python -m fms_acceleration.cli install fms_acceleration_aadp && \
python -m fms_acceleration.cli install fms_acceleration_moe; \
python -m fms_acceleration.cli install fms_acceleration_moe && \
python -m fms_acceleration.cli install fms_acceleration_odm; \
fi

RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \
28 changes: 21 additions & 7 deletions docs/advanced-data-preprocessing.md
@@ -34,7 +34,7 @@ process the datasets. Users can currently pass both YAML or JSON based configura
The data config schema is designed to define datasets and their processing strategies in a structured way.

It consists of the following top-level keys:
- `datapreprocessor`: Defines global data processing parameters, such as the type (`default`), sampling stopping strategy (`all_exhausted` or `first_exhausted`), and sampling seed for reproducibility.
- `datapreprocessor`: Defines global data processing parameters, such as the type (`default` or `odm`), sampling stopping strategy (`all_exhausted` or `first_exhausted`), and sampling seed for reproducibility.
- `datasets`: A list of dataset configurations, each describing the dataset name, paths, optional builders, sampling ratios, and data handlers.

At the top level, the data config schema looks like this:
@@ -129,11 +129,29 @@ definitions:
Users can create a data config file in either YAML or JSON format (we provide YAML examples for ease of use). The file should follow the schema outlined above with the following parameters:

`datapreprocessor`:
- `type` (optional, str): Type of data preprocessor, `default` is currently the only supported type.
- `type` (optional, str): Type of data preprocessor; `default` and `odm` are the two supported types. Using `odm` requires [installing](./tuning-techniques.md#fms-acceleration) the `fms_acceleration_odm` package.
- `streaming` (optional, bool): Stream datasets using [IterableDatasets](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.IterableDataset).
- `sampling_stopping_strategy` (optional, str): Dataset interleave stopping strategy in case of choosing to mix multiple datasets by weight, supported values are [`all_exhausted` or `first_exhausted`](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy), defaults to `all_exhausted`.
- `seed` (optional, int): [seed](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.seed) to use for interleaving datasets, for reproducibility choose same value, defaults to 42.
- `chat_template` (optional, str): pass `chat_template` via data_config for multi-turn data, replaces existing default chat template.
- `odm` (optional): required when `type` is `odm`; provides the configuration for online data mixing, described in the section below.

Data handlers are customizable components within the data config that allow users to preprocess or manipulate individual datasets. We use [Hugging Face Map API](https://huggingface.co/docs/datasets/en/process#map) to apply these routines.
These functions can process the dataset in any way users require and the `list` of data handlers specified for each dataset are applied in order.
Each data handler has:
- `name`: The handler's unique identifier.
- `arguments`: A dictionary of parameters specific to the handler (see the sketch below).
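
For illustration, here is a handler entry sketch; the handler name and arguments mirror the sample config added later in this PR, and the column names are placeholders:

```yaml
data_handlers:
  - name: tokenize_and_apply_input_masking   # the handler's unique identifier
    arguments:                               # parameters specific to this handler
      remove_columns: all
      batched: false
      fn_kwargs:
        input_column_name: input
        output_column_name: output
```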

#### Online data mixing section

The `odm` config is required when the `datapreprocessor` `type` is `odm` and has the following fields. A minimal example follows this list.

`odm`:
- `update_interval` (optional, int, defaults to `1`): A Multi-Armed Bandit (MAB) learns from training signals and provides mixing probabilities across the datasets. `update_interval` defines how frequently, in steps, the MAB is updated with training signals.
- `sampling_interval` (optional, int, defaults to `1`): Defines how frequently, in samples, a dataset is chosen to sample from via the MAB.
- `reward_type` (optional, str, defaults to `entropy`): Type of reward used to update the MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, and `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards).
- `gamma` (optional, float, defaults to `0.1`): MAB hyper-parameter analogous to an exploration factor.
- `eta` (optional, float, defaults to `0.1`): MAB hyper-parameter analogous to a learning rate.
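
A minimal sketch of a data config with online data mixing enabled (values are illustrative; the top-level key and layout follow the sample config `multiple_datasets_with_odm.yaml` added in this PR):

```yaml
dataprocessor:
  type: odm                  # requires the fms_acceleration_odm package
  seed: 42
  odm:
    update_interval: 1       # update the MAB with training signals every step
    sampling_interval: 1     # ask the MAB which dataset to draw each sample from
    reward_type: validation_loss
    gamma: 0.1               # exploration factor
    eta: 0.1                 # learning rate
```

For intuition only, an EXP3-style bandit keeps a weight $w_i$ per dataset and turns rewards into sampling probabilities. One standard formulation, stated here as an assumption rather than the plugin's exact update rule, is

$$p_i = (1-\gamma)\frac{w_i}{\sum_j w_j} + \frac{\gamma}{K}, \qquad w_i \leftarrow w_i \exp\!\left(\eta\,\hat{r}_i / p_i\right),$$

where $K$ is the number of datasets and $\hat{r}_i$ is the reward observed when dataset $i$ is drawn (zero otherwise).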

`datasets` (list):
- `name` (optional, str): A unique identifier for the dataset.
@@ -143,11 +161,6 @@ Users can create a data config file in any of YAML or JSON format they choose (w
- `split` (optional, dict[str: float]): Defines how to split the dataset into training and validation sets. Requires both `train` and `validation` keys.
- `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.

Data handlers are customizable components within the data config that allow users to preprocess or manipulate individual datasets. We use [Hugging Face Map API](https://huggingface.co/docs/datasets/en/process#map) to apply these routines.
These functions can process the dataset in any way users require and the `list` of data handlers specified for each dataset are applied in order.
Each data handler has:
- `name`: The handler's unique identifier.
- `arguments`: A dictionary of parameters specific to the handler.

We provide some sample `data_configs` in [predefined_data_configs](../tests/artifacts/predefined_data_configs/).

@@ -192,6 +205,7 @@ We also allow users to pass a [`seed`](https://huggingface.co/docs/datasets/v3.2

Note: If a user specifies data sampling, they can expect the datasets to be mixed and individual samples to remain intact, unless the `max_seq_len` argument is smaller than the length of individual samples in the dataset.


### Dataset Splitting

In addition to [sampling and mixing](#data-mixing), our library supports **dataset splitting**, which allows users to split a dataset into training and validation splits using the `split` field in the dataset config.
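
For illustration, a `split` sketch for one dataset entry (the ratios are placeholders and mirror the sample config added in this PR):

```yaml
datasets:
  - name: dataset_1
    split:
      train: 0.8        # fraction used for training
      validation: 0.2   # fraction used for validation
```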
3 changes: 2 additions & 1 deletion docs/tuning-techniques.md
@@ -470,8 +470,9 @@ The list of configurations for various `fms_acceleration` plugins:
- `--multipack`: technique for *multi-gpu training* to balance out number of tokens processed in each device, to minimize waiting time.
- [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental):
- `--fast_moe`: trains MoE models in parallel with [Scatter MoE kernels](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/accelerated-moe#fms-acceleration-for-mixture-of-experts), increasing throughput and decreasing memory usage.
- [odm_config](./tuning/config/acceleration_configs/odm.py) (experimental): this plugin dynamically mixes datasets online during training, adapting to training signals. See [online data mixing](./online-data-mixing.md) and the [PyTorch conf poster](https://static.sched.com/hosted_files/pytorchconference/70/PyTorch%20Native%20Online%20Dynamic%20Reward%20Based%20Data%20Mixing%20Framework.pdf) for usage with `data_config`; a minimal sketch follows.
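
A minimal data config sketch that selects the ODM data preprocessor (illustrative values; see [online data mixing](./online-data-mixing.md) for the full set of fields):

```yaml
dataprocessor:
  type: odm              # requires fms_acceleration_odm
  odm:
    reward_type: entropy # default reward
```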

Notes:
* `quantized_lora_config` requires that it be used along with LoRA tuning technique. See [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass.
* When setting `--auto_gptq triton_v2` plus note to also pass `--torch_dtype float16` and `--fp16`, or an exception will be raised. This is because these kernels only support this dtype.
* When using `fused_ops_and_kernels` together with `quantized_lora_config`,
3 changes: 3 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
@@ -34,6 +34,9 @@
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split.yaml"
)
DATA_CONFIG_MULTIPLE_DATASETS_ODM_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_odm.yaml"
)
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML_2 = os.path.join(
    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split_2.yaml"
)
70 changes: 70 additions & 0 deletions tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml
@@ -0,0 +1,70 @@
dataprocessor:
  type: odm
  sampling_stopping_strategy: first_exhausted # ignored
  seed: 66
  odm:
    update_interval: 1 # update every step
    sampling_interval: 1 # sample category for every sample
    reward_type: validation_loss # uses eval loss of each dataset as reward
    gamma: 0.1 # MAB hyper-parameter
    eta: 0.2 # MAB hyper-parameter
datasets:
  - name: dataset_1
    split:
      train: 0.8
      validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
    sampling: 0.3 # ignored
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
  - name: dataset_2
    split:
      train: 0.6
      validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
    sampling: 0.4 # ignored
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
  - name: dataset_3
    split:
      train: 0.4
      validation: 0.1 # validation set is also used in ODM reward computation when reward_type is validation_loss.
    sampling: 0.3 # ignored
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
  - name: dataset_4
    split:
      train: 0.0
      validation: 0.3 # validation set is also used in ODM reward computation when reward_type is validation_loss.
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_column_name: input
            output_column_name: output
3 changes: 3 additions & 0 deletions tests/artifacts/testdata/__init__.py
@@ -44,6 +44,9 @@
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
    JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
)
NESTFUL_DATA_INPUT_OUTPUT_JSONL = os.path.join(
    JSONL_DATA_DIR, "nestful_100_samples_input_output.jsonl"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
    ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
)