Commit ea2eb7a

feat: add feature to split training dataset to train and validate via dataconfig (#560)
* feat: add train_test_split functionality via dataconfig
* docs: add documentation for dataset split support
* feat: add evaluation_strategy to TrainingArguments and minor fix

Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent c28a4ed commit ea2eb7a

File tree

10 files changed: +374 additions, -70 deletions


docs/advanced-data-preprocessing.md

Lines changed: 72 additions & 1 deletion
@@ -119,7 +119,7 @@ Users can create a data config file in any of YAML or JSON format they choose (w
 - `type` (optional, str): Type of data preprocessor, `default` is currently the only supported type.
 - `streaming` (optional, bool): Stream datasets using [IterableDatasets](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.IterableDataset).
 - `sampling_stopping_strategy` (optional, str): Dataset interleave stopping strategy in case of choosing to mix multiple datasets by weight, supported values are [`all_exhausted` or `first_exhausted`](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy), defaults to `all_exhausted`.
-- `sampling_seed` (optional, int): [Sampling seed](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.seed) to use for interleaving datasets, for reproducibility choose same value, defaults to 42.
+- `seed` (optional, int): [seed](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.seed) to use for interleaving datasets, for reproducibility choose the same value, defaults to 42.
 - `chat_template` (optional, str): pass `chat_template` via data_config for multi-turn data, replaces existing default chat template.

 `datasets` (list):
@@ -129,6 +129,7 @@ Users can create a data config file in any of YAML or JSON format they choose (w
 - `rename_columns` (optional, dict[str:str]): Specifies a dictionary of columns to rename like `{"old_name": "new_name"}` at dataset load time. *Applied before `retain_columns` if both are specified*.
 - `retain_columns` (optional, list[str]): Specifies a list of columns to retain, e.g. `["input_ids", "labels"]`; every other column will be dropped at dataset load time. *Applied strictly after `rename_columns` if both are specified*.
 - `sampling` (optional, float): The sampling ratio (0.0 to 1.0) with which to sample a dataset in case of interleaving.
+- `split` (optional, dict[str: float]): Defines how to split the dataset into training and validation sets. Requires both `train` and `validation` keys.
 - `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.

 Data handlers are customizable components within the data config that allow users to preprocess or manipulate individual datasets. We use [Hugging Face Map API](https://huggingface.co/docs/datasets/en/process#map) to apply these routines.
@@ -184,6 +185,76 @@ We also allow users to pass a [`seed`](https://huggingface.co/docs/datasets/v3.2

 Note: If a user specifies data sampling, they can expect the datasets to be mixed and individual samples in the dataset not to be broken unless the max_seq_len argument is smaller than the length of individual samples in the dataset.

+### Dataset Splitting
+
+In addition to [sampling and mixing](#data-mixing), our library supports **dataset splitting**, which allows users to split a dataset into training and validation splits using the `split` field in the dataset config.
+
+This is especially useful when users want to split a single dataset (or multiple datasets) internally instead of supplying separate files for training and validation.
+
+#### How to Use
+
+The `split` field in each dataset config allows users to internally divide a dataset into `train` and `validation` sets using fractional ratios.
+
+To use it, specify both `train` and `validation` ratio values under the `split` key for each dataset. Example:
+
+```yaml
+datasets:
+  - name: my_dataset
+    split:
+      train: 0.8
+      validation: 0.2
+    data_paths:
+      - "path/to/data.jsonl"
+```
+
+### Split Support for Streaming vs Non-Streaming Datasets
+
+**Non-Streaming Datasets (`Dataset`, `DatasetDict`)**:
+- Supports arbitrary train/validation splits.
+- Both `train` and `validation` keys must be present under `split`.
+- The sum of `train + validation` must be in `(0, 1]`; a sum less than 1.0 means only a subset of the data is used.
+- If no `split` is defined, the dataset is returned unchanged.
+
+**Streaming Datasets (`IterableDataset`, `IterableDatasetDict`)**:
+- Only supports full splits:
+  - Either `train: 1.0, validation: 0.0`
+  - Or `train: 0.0, validation: 1.0`
+- Partial splits like `train: 0.8, validation: 0.2` are not supported and will raise a `NotImplementedError`.
+- If no `split` is defined, the dataset is returned unchanged.
+- Streaming behavior must be explicitly enabled via `dataprocessor.streaming: true`.
+
+### Using Separate Files for Train and Validation Splits
+
+If you want to use **separate files for training and validation**, you can define them as **separate dataset entries** in the `datasets` section of your config.
+In each entry:
+
+- Specify the corresponding file in the `data_paths` field.
+- Set the `split` value to either `train: 1.0` or `validation: 1.0` as appropriate.
+
+This allows you to fully control which file is used for which purpose, without relying on automatic or in-place splitting.
+
+#### Example
+
+```yaml
+datasets:
+  - name: my_train_set
+    split:
+      train: 1.0
+    data_paths:
+      - "path/to/train.jsonl"
+  - name: my_val_set
+    split:
+      validation: 1.0
+    data_paths:
+      - "path/to/val.jsonl"
+```
+
+### **Note:**
+> - While passing a validation dataset via the command line is possible using the `validation_data_path` argument, **this argument is not compatible with `data_config`**. If you're using a `data_config`, define the validation set within it using a `split: validation: 1.0` entry instead, as shown [here](#using-separate-files-for-train-and-validation-splits).
+> - Dataset splitting is performed based on the `split` configuration, supporting only `"train"` and `"validation"` splits. Support for a `"test"` split is not yet available.
+> - **Only the `"train"` split is sampled**, and **sampling is done after splitting**. This ensures that validation remains consistent and unbiased, while allowing training to be performed on a controlled subset if desired.
+> - **⚠️ Users must explicitly set the [`eval_strategy`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.eval_strategy) in the Trainer's arguments to a valid value (e.g., `"steps"` or `"epoch"`) for evaluation to run. Splitting the dataset alone does not trigger evaluation and will likely result in an error if `eval_strategy` is left unset.**

 ### Data Streaming
 Dataset streaming allows users to utilize the functionality of iterable datasets to pass in data piece by piece, avoiding memory constraints with large datasets for use-cases like extended pre-training.
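For reference, the streaming constraint documented above (full splits only, with `dataprocessor.streaming: true`) can be satisfied with a config like the following. This is an illustrative sketch, not part of the commit; the dataset name, output file, and data path are placeholders.

```python
# Sketch: build a streaming-compatible data config programmatically and dump it to
# YAML. Streaming datasets only accept full splits (train: 1.0 / validation: 0.0 or
# the reverse); partial splits would raise NotImplementedError per the docs above.
import yaml

streaming_config = {
    "dataprocessor": {"type": "default", "streaming": True},
    "datasets": [
        {
            "name": "my_streaming_train_set",           # placeholder name
            "split": {"train": 1.0, "validation": 0.0},  # full split only
            "data_paths": ["path/to/large_train.jsonl"],  # placeholder path
        }
    ],
}

with open("streaming_data_config.yaml", "w") as f:
    yaml.safe_dump(streaming_config, f)
```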

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@
 DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
 )
+DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split.yaml"
+)
 DATA_CONFIG_MULTITURN_DATA_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multi_turn_data_with_chat_template.yaml"
 )
tests/artifacts/predefined_data_configs/multiple_datasets_with_sampling_and_split.yaml

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+dataprocessor:
+  type: default
+  sampling_stopping_strategy: first_exhausted
+  seed: 66
+datasets:
+  - name: dataset_1
+    split:
+      train: 0.8
+      validation: 0.2
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output
+  - name: dataset_2
+    split:
+      train: 0.6
+      validation: 0.2
+    sampling: 0.4
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output
+  - name: dataset_3
+    split:
+      train: 0.4
+      validation: 0.1
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output

tests/data/test_data_preprocessing.py

Lines changed: 65 additions & 2 deletions
@@ -36,6 +36,7 @@
 )
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
     DATA_CONFIG_MULTITURN_DATA_YAML,
     DATA_CONFIG_PRETOKENIZE_DATA_YAML,
@@ -1459,6 +1460,68 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
     )


+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                ],
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML,
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_datasets_datafiles_sampling_and_split(
+    datafiles, datasetconfigname
+):
+    """Ensure that multiple datasets with multiple files are formatted and validated correctly."""
+    with open(datasetconfigname, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = datafiles[0]
+    yaml_content["datasets"][1]["data_paths"] = datafiles[1]
+    yaml_content["datasets"][2]["data_paths"] = datafiles[2]
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",
+    )
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
+    assert isinstance(train_set, Dataset)
+    assert isinstance(eval_set, Dataset)
+    assert set(["input_ids", "attention_mask", "labels"]).issubset(
+        set(eval_set.column_names)
+    )
+    # training_data_path/validation_data_path args are not supported with data_config
+    with pytest.raises(ValueError):
+        data_args.training_data_path = "/tmp/some/path"
+        process_dataargs(
+            data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+        )
+
+
 @pytest.mark.parametrize(
     "data_args, is_padding_free",
     [
@@ -1690,7 +1753,7 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
         tokenizer=tokenizer,
     )
     datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])]
-    train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig)
+    train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig)

     assert isinstance(train_dataset, Dataset)
     assert set(train_dataset.column_names) == column_names
@@ -1812,7 +1875,7 @@ def test_rename_and_select_dataset_columns(
             name=datasetconfigname, data_paths=data_paths, data_handlers=handlers
         )
     ]
-    train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig)
+    train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig)

     assert isinstance(train_dataset, Dataset)
     assert set(train_dataset.column_names) == set(final)
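As the updated assertions above show, `process_dataset_configs` now returns a `(train_dataset, eval_dataset)` pair instead of a single dataset, and `DataSetConfig` carries the new `split` field. A minimal sketch of constructing such a config directly, assuming the `tuning` package (fms-hf-tuning) is installed; the dataset name and path are placeholders:

```python
# Sketch (not from the commit): a DataSetConfig with the new `split` field.
# Note that constructing the dataclass directly bypasses the validation in
# _validate_dataset_config, which is applied when loading a data_config file.
from tuning.data.data_config import DataSetConfig

cfg = DataSetConfig(
    name="my_dataset",                            # placeholder
    data_paths=["path/to/data.jsonl"],            # placeholder
    split={"train": 0.8, "validation": 0.2},      # both keys required
)
print(cfg.split)  # {'train': 0.8, 'validation': 0.2}
```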

tests/test_sft_trainer.py

Lines changed: 2 additions & 0 deletions
@@ -1328,6 +1328,7 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
     with tempfile.TemporaryDirectory() as tempdir:

         data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = None
         data_args.chat_template = "{% for message in messages['messages'] %}\
 {% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + eos_token }}\
 {% elif message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + eos_token }}\
@@ -1422,6 +1423,7 @@ def test_run_chat_style_ft_using_dataconfig_for_chat_template(
     with tempfile.TemporaryDirectory() as tempdir:

         data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = None
         if dataconfigfile == DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML:
             data_args.response_template = "<|start_of_role|>assistant<|end_of_role|>"
             data_args.instruction_template = "<|start_of_role|>user<|end_of_role|>"

tuning/config/configs.py

Lines changed: 11 additions & 0 deletions
@@ -226,6 +226,17 @@ class TrainingArguments(transformers.TrainingArguments):
             "for all PEFT runs by the library internally."
         },
     )
+    eval_strategy: str = field(
+        default="no",
+        metadata={
+            "help": "The evaluation strategy to adopt during training. "
+            "Possible values are 'no' (no evaluation during training), "
+            "'epoch' (evaluate at the end of each epoch), "
+            "'steps' (evaluate every `eval_steps`). "
+            "Note: Splitting the dataset does not automatically trigger evaluation; "
+            "you must explicitly set this value to enable evaluation."
+        },
+    )


 @dataclass
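Because splitting alone does not trigger evaluation, runs that rely on the new split support also need `eval_strategy` set to a valid value. A minimal sketch mirroring the test construction above, assuming the `tuning` package is importable; the values shown are placeholders:

```python
# Sketch (not from the commit): enabling evaluation via the new eval_strategy field.
from tuning.config import configs

train_args = configs.TrainingArguments(
    packing=False,
    max_seq_length=1024,
    output_dir="tmp",
    eval_strategy="epoch",  # "no" (default), "epoch", or "steps"
)
```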

tuning/data/data_config.py

Lines changed: 17 additions & 5 deletions
@@ -38,14 +38,15 @@ class DataSetConfig:
     builder: Optional[str] = None  # Referring to Hugging Face dataset builder
     sampling: Optional[float] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None
+    split: Optional[Dict[str, float]] = None


 @dataclass
 class DataPreProcessorConfig:
     type: Optional[str] = "default"
     sampling_stopping_strategy: Optional[str] = "all_exhausted"
     # Default seed is not none to ensure reproducability
-    sampling_seed: Optional[float] = 42
+    seed: Optional[float] = 42
     streaming: Optional[bool] = False
     chat_template: Optional[str] = None

@@ -120,6 +121,17 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
         c.data_handlers = []
         for handler in kwargs["data_handlers"]:
             c.data_handlers.append(_validate_data_handler_config(handler))
+    if "split" in kwargs and kwargs["split"] is not None:
+        split = kwargs["split"]
+        assert isinstance(
+            split, dict
+        ), "split must be a dictionary of split_name: ratio"
+        for key, value in split.items():
+            assert isinstance(key, str), f"split key '{key}' must be a string"
+            assert (
+                isinstance(value, (float, int)) and 0.0 <= value <= 1.0
+            ), f"split ratio for '{key}' must be a float in [0.0, 1.0], got {value}"
+        c.split = {k: float(v) for k, v in split.items()}
     return c


@@ -140,10 +152,10 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConf
             "all_exhausted",
         ], "allowed sampling stopping strategies are all_exhausted(default) or first_exhausted"
         c.sampling_stopping_strategy = strategy
-    if "sampling_seed" in kwargs:
-        seed = kwargs["sampling_seed"]
-        assert isinstance(seed, int), "sampling seed should be int"
-        c.sampling_seed = seed
+    if "seed" in kwargs:
+        seed = kwargs["seed"]
+        assert isinstance(seed, int), "seed should be int"
+        c.seed = seed
     if "streaming" in kwargs:
         streaming = kwargs["streaming"]
         assert isinstance(streaming, bool), f"streaming: {streaming} should be a bool"
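As a quick reference, the split rules asserted above (a dict of `split_name: ratio` with each ratio in `[0.0, 1.0]`) can be exercised with a small standalone mirror. This helper is illustrative only and not part of the library API:

```python
# Illustrative mirror of the split validation added in _validate_dataset_config.
def check_split(split):
    assert isinstance(split, dict), "split must be a dictionary of split_name: ratio"
    for key, value in split.items():
        assert isinstance(key, str), f"split key '{key}' must be a string"
        assert (
            isinstance(value, (float, int)) and 0.0 <= value <= 1.0
        ), f"split ratio for '{key}' must be a float in [0.0, 1.0], got {value}"
    return {k: float(v) for k, v in split.items()}

print(check_split({"train": 0.8, "validation": 0.2}))  # accepted

try:
    check_split({"train": 1.5, "validation": 0.2})  # ratio out of range
except AssertionError as err:
    print(err)
```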

tuning/data/data_handlers.py

Lines changed: 1 addition & 2 deletions
@@ -116,13 +116,12 @@ def tokenize_and_apply_input_masking(
     # These are made available by the data preprocessor framework
     try:
         tokenizer = kwargs["tokenizer"]
-        column_names = kwargs["column_names"]
     except KeyError as e:
         raise RuntimeError(
             "Data processor failed to pass default args to data handlers"
         ) from e

-    if column_names and (input_column_name or output_column_name) not in column_names:
+    if (input_column_name or output_column_name) not in element:
         raise ValueError(
             f"Dataset should contain {input_column_name} \
                 and {output_column_name} field if \
