
Commit 4ec1340

feat: add online data mixing plugin (#612)
feat: add odm plugin

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent f337875 commit 4ec1340

File tree

17 files changed: +632 −34 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 - [Advanced Data Processing](./docs/advanced-data-preprocessing.md#data-config)
 - [Guidelines on supported data formats](./docs/advanced-data-preprocessing.md#use-cases-supported-via-command-line-argument-training_data_path)
 - [Offline data processing](#offline-data-preprocessing)
+- [Online data mixing](./docs/online-data-mixing.md)
 - [Additional Frameworks](#additional-frameworks)
 - [Inference](#inference)
 - [Validation](#validation)

build/Dockerfile

Lines changed: 2 additions & 0 deletions
@@ -165,12 +165,14 @@ RUN --mount=type=cache,target=/home/${USER}/.cache/pip,uid=${USER_UID} \
 # fms_acceleration_foak = Fused LoRA and triton kernels
 # fms_acceleration_aadp = Padding-Free Flash Attention Computation
 # fms_acceleration_moe = Parallelized Mixture of Experts
+# fms_acceleration_odm = Online Data Mixing
 RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
     python -m pip install --user "$(head bdist_name)[fms-accel]"; \
     python -m fms_acceleration.cli install fms_acceleration_peft; \
     python -m fms_acceleration.cli install fms_acceleration_foak; \
     python -m fms_acceleration.cli install fms_acceleration_aadp; \
     python -m fms_acceleration.cli install fms_acceleration_moe; \
+    python -m fms_acceleration.cli install fms_acceleration_odm; \
 fi

 RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \

build/nvcr.Dockerfile

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,8 @@ RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
     python -m fms_acceleration.cli install fms_acceleration_peft && \
     python -m fms_acceleration.cli install fms_acceleration_foak && \
     python -m fms_acceleration.cli install fms_acceleration_aadp && \
-    python -m fms_acceleration.cli install fms_acceleration_moe; \
+    python -m fms_acceleration.cli install fms_acceleration_moe && \
+    python -m fms_acceleration.cli install fms_acceleration_odm; \
 fi

 RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \

docs/advanced-data-preprocessing.md

Lines changed: 21 additions & 7 deletions
@@ -34,7 +34,7 @@ process the datasets. Users can currently pass both YAML or JSON based configura
 The data config schema is designed to define datasets and their processing strategies in a structured way.

 It consists of the following top-level keys:
-- `datapreprocessor`: Defines global data processing parameters, such as the type (`default`), sampling stopping strategy (`all_exhausted` or `first_exhausted`), and sampling seed for reproducibility.
+- `datapreprocessor`: Defines global data processing parameters, such as the type (`default` or `odm`), sampling stopping strategy (`all_exhausted` or `first_exhausted`), and sampling seed for reproducibility.
 - `datasets`: A list of dataset configurations, each describing the dataset name, paths, optional builders, sampling ratios, and data handlers.

 At the top level, the data config schema looks like this:
@@ -129,11 +129,29 @@ definitions:
 Users can create a data config file in any of YAML or JSON format they choose (we provide examples of YAML for ease of use). The file should follow the schema outlined above with the following parameters:

 `datapreprocessor`:
-- `type` (optional, str): Type of data preprocessor, `default` is currently the only supported type.
+- `type` (optional, str): Type of data preprocessor; `default` and `odm` are the two supported types. Using `odm` requires [installation](./tuning-techniques.md#fms-acceleration) of the `fms_acceleration_odm` package.
 - `streaming` (optional, bool): Stream datasets using [IterableDatasets](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.IterableDataset).
 - `sampling_stopping_strategy` (optional, str): Dataset interleave stopping strategy in case of choosing to mix multiple datasets by weight, supported values are [`all_exhausted` or `first_exhausted`](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy), defaults to `all_exhausted`.
 - `seed` (optional, int): [seed](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.seed) to use for interleaving datasets, for reproducibility choose same value, defaults to 42.
 - `chat_template` (optional, str): pass `chat_template` via data_config for multi-turn data, replaces existing default chat template.
+- `odm` (optional): if `type` is `odm`, this field is required and provides the configuration for online data mixing.
+
+Data handlers are customizable components within the data config that allow users to preprocess or manipulate individual datasets. We use [Hugging Face Map API](https://huggingface.co/docs/datasets/en/process#map) to apply these routines.
+These functions can process the dataset in any way users require and the `list` of data handlers specified for each dataset are applied in order.
+Each data handler has:
+- `name`: The handler's unique identifier.
+- `arguments`: A dictionary of parameters specific to the handler.
+
+#### Online data mixing section
+
+The `odm` config has the following fields and is required when the `datapreprocessor` `type` is `odm`.
+
+`odm`:
+- `update_interval` (optional, int, defaults to `1`): A Multi-Armed Bandit (MAB) learns from training signals and provides mixing probabilities across datasets. `update_interval` defines how often, in steps, the MAB is updated with training signals.
+- `sampling_interval` (optional, int, defaults to `1`): Defines how often, in samples, the MAB chooses a dataset to sample from.
+- `reward_type` (optional, str, defaults to `entropy`): Type of reward used to update the MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards).
+- `gamma` (optional, float, defaults to `0.1`): MAB hyper-parameter, analogous to an exploration factor.
+- `eta` (optional, float, defaults to `0.1`): MAB hyper-parameter, analogous to a learning rate (see the bandit sketch after this file's diff for how `gamma` and `eta` enter the update).

 `datasets` (list):
 - `name` (optional, str): A unique identifier for the dataset.
@@ -143,11 +161,6 @@ Users can create a data config file in any of YAML or JSON format they choose (w
 - `split` (optional, dict[str: float]): Defines how to split the dataset into training and validation sets. Requires both `train` and `validation` keys.
 - `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.

-Data handlers are customizable components within the data config that allow users to preprocess or manipulate individual datasets. We use [Hugging Face Map API](https://huggingface.co/docs/datasets/en/process#map) to apply these routines.
-These functions can process the dataset in any way users require and the `list` of data handlers specified for each dataset are applied in order.
-Each data handler has:
-- `name`: The handler's unique identifier.
-- `arguments`: A dictionary of parameters specific to the handler.

 We do provide some sample `data_configs` here, [predefined_data_configs](../tests/artifacts/predefined_data_configs/).
@@ -192,6 +205,7 @@ We also allow users to pass a [`seed`](https://huggingface.co/docs/datasets/v3.2

 Note: If a user specifies data sampling they can expect the datasets to be mixed and individual samples in the dataset to not be broken unless the max_seq_len argument is smaller than the length of individual samples in the dataset

+
 ### Dataset Splitting

 In addition to [sampling and mixing](#data-mixing), our library supports **dataset splitting**, which allows users to split a dataset into training and validation splits using the `split` field in the dataset config.
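
The `update_interval`, `sampling_interval`, `gamma`, and `eta` fields above map naturally onto an EXP3-style multi-armed bandit. The sketch below is a minimal illustration of that formulation only; the class and method names are hypothetical and are not the `fms_acceleration_odm` API.

```python
import math
import random

# Minimal EXP3-style bandit sketch, illustrating what `gamma` (exploration)
# and `eta` (learning rate) control. Hypothetical names; the real
# fms_acceleration_odm implementation may differ in its exact update rule.
class Exp3Mixer:
    def __init__(self, num_datasets: int, gamma: float = 0.1, eta: float = 0.1):
        self.gamma = gamma
        self.eta = eta
        self.weights = [1.0] * num_datasets

    def probabilities(self) -> list[float]:
        # Mix the learned distribution with uniform exploration.
        total = sum(self.weights)
        return [
            (1 - self.gamma) * w / total + self.gamma / len(self.weights)
            for w in self.weights
        ]

    def sample(self) -> int:
        # Called every `sampling_interval` samples to pick the next dataset.
        return random.choices(range(len(self.weights)), self.probabilities())[0]

    def update(self, arm: int, reward: float) -> None:
        # Called every `update_interval` steps with a training signal for the
        # sampled dataset (e.g. a loss- or entropy-based reward).
        prob = self.probabilities()[arm]
        self.weights[arm] *= math.exp(self.eta * reward / prob)
```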

docs/tuning-techniques.md

Lines changed: 2 additions & 1 deletion
@@ -470,8 +470,9 @@ The list of configurations for various `fms_acceleration` plugins:
 - `--multipack`: technique for *multi-gpu training* to balance out number of tokens processed in each device, to minimize waiting time.
 - [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental):
   - `--fast_moe`: trains MoE models in parallel with [Scatter MoE kernels](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/accelerated-moe#fms-acceleration-for-mixture-of-experts), increasing throughput and decreasing memory usage.
+- [odm_config](./tuning/config/acceleration_configs/odm.py) (experimental): See [online data mixing](./online-data-mixing.md) and the [PyTorch conf poster](https://static.sched.com/hosted_files/pytorchconference/70/PyTorch%20Native%20Online%20Dynamic%20Reward%20Based%20Data%20Mixing%20Framework.pdf) for usage with data_config. This plugin dynamically mixes datasets online during training, adapting to training signals (a sketch of such a config follows this diff).

-Notes:
+Notes:
 * `quantized_lora_config` requires that it be used along with LoRA tuning technique. See [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass.
 * When setting `--auto_gptq triton_v2` plus note to also pass `--torch_dtype float16` and `--fp16`, or an exception will be raised. This is because these kernels only support this dtype.
 * When using `fused_ops_and_kernels` together with `quantized_lora_config`,
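
For orientation, a hypothetical sketch of what an ODM acceleration config dataclass could look like; the field names mirror the documented `odm` data_config section, but the actual class in `tuning/config/acceleration_configs/odm.py` may differ.

```python
from dataclasses import dataclass

# Hypothetical config sketch; fields mirror the documented `odm` section.
# Not the actual contents of tuning/config/acceleration_configs/odm.py.
@dataclass
class ODMConfig:
    update_interval: int = 1       # steps between MAB updates
    sampling_interval: int = 1     # samples between MAB dataset choices
    reward_type: str = "entropy"   # e.g. train_loss, validation_loss, gradnorm
    gamma: float = 0.1             # exploration factor
    eta: float = 0.1               # learning rate
```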

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@
 DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split.yaml"
 )
+DATA_CONFIG_MULTIPLE_DATASETS_ODM_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_odm.yaml"
+)
 DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML_2 = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split_2.yaml"
 )
tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+dataprocessor:
+  type: odm
+  sampling_stopping_strategy: first_exhausted # ignored
+  seed: 66
+  odm:
+    update_interval: 1 # update every step
+    sampling_interval: 1 # sample category for every sample
+    reward_type: validation_loss # uses eval loss of each dataset as reward
+    gamma: 0.1 # MAB hyper-parameter
+    eta: 0.2 # MAB hyper-parameter
+datasets:
+  - name: dataset_1
+    split:
+      train: 0.8
+      validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
+    sampling: 0.3 # ignored
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output
+  - name: dataset_2
+    split:
+      train: 0.6
+      validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
+    sampling: 0.4 # ignored
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output
+  - name: dataset_3
+    split:
+      train: 0.4
+      validation: 0.1 # validation set is also used in ODM reward computation when reward_type is validation_loss.
+    sampling: 0.3 # ignored
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output
+  - name: dataset_4
+    split:
+      train: 0.0
+      validation: 0.3 # validation set is also used in ODM reward computation when reward_type is validation_loss.
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output
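
A quick way to sanity-check a config like the one above is to load it and read back the ODM settings. A minimal sketch, assuming PyYAML is available and the `FILE_PATH` placeholders have been filled in; the file path and prints are illustrative only.

```python
import yaml  # assumes PyYAML is installed

# Load the ODM data config and inspect the bandit settings and dataset names.
with open("multiple_datasets_with_odm.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

odm = cfg["dataprocessor"]["odm"]
print(odm["reward_type"], odm["gamma"], odm["eta"])  # validation_loss 0.1 0.2
print([d["name"] for d in cfg["datasets"]])          # ['dataset_1', ..., 'dataset_4']
```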

tests/artifacts/testdata/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -44,6 +44,9 @@
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
     JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
 )
+NESTFUL_DATA_INPUT_OUTPUT_JSONL = os.path.join(
+    JSONL_DATA_DIR, "nestful_100_samples_input_output.jsonl"
+)
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
     ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
 )
