Commit 381fdd5

Authored by Dushyant Behl
feat: Rename data handlers and add a new one for EPT scenarios (#460)
* Rename data handlers and add a data handler for the EPT use case.
* Fix minor bug in formatting where input_ids was missing post duplication.
* Add docstring.
* Change name of dataset in data config yaml.

Signed-off-by: Dushyant Behl <[email protected]>
1 parent d48d483

11 files changed (+254 −26 lines)

docs/advanced-data-preprocessing.md

Lines changed: 5 additions & 3 deletions
@@ -206,14 +206,16 @@ Users can also pass any number of `kwargs` arguments required for each data handler
 This library currently supports the following [preexisting data handlers](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py#L156):
 - `tokenize_and_apply_input_masking`:
     Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets.
-- `apply_dataset_formatting`:
-    Formats a dataset by appending an EOS token to a specified field.
+- `add_tokenizer_eos_token`:
+    Appends the tokenizer's EOS token to a specified dataset field.
 - `apply_custom_data_formatting_template`:
     Applies a custom template (e.g., Alpaca style) to format dataset elements.
-- `apply_custom_data_formatting_jinja_template`:
+- `apply_custom_jinja_template`:
     Applies a custom Jinja template (e.g., Alpaca style) to format dataset elements.
 - `apply_tokenizer_chat_template`:
     Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
+- `duplicate_columns`:
+    Duplicates one column of the dataset to another column.

 These handlers can be requested by name, and the arguments each one accepts can be looked up [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py).
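For orientation, here is a minimal sketch of what a handler like `duplicate_columns` plausibly looks like, inferred from the tests added in this commit; the shipped implementation lives in `tuning/data/data_handlers.py` and may differ in its validation details and error messages.

```python
# Sketch only: behavior inferred from tests/data/test_data_handlers.py below;
# not the repository's exact implementation.
import copy


def duplicate_columns(element, old_column=None, new_column=None, **kwargs):
    # The tests expect a ValueError when either column name is missing
    # or when old_column is absent from the dataset element.
    if not old_column or not new_column:
        raise ValueError("duplicate_columns requires both old_column and new_column")
    if old_column not in element:
        raise ValueError(f"column {old_column} not found in dataset element")
    # Keep the original column and add a copy under the new name, so both
    # input_ids and labels survive the map (the formatting bug fixed here).
    element[new_column] = copy.deepcopy(element[old_column])
    return element
```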

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,3 +34,6 @@
 DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
 )
+DATA_CONFIG_DUPLICATE_COLUMNS = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "duplicate_columns.yaml"
+)

tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ datasets:
     data_paths:
       - "FILE_PATH"
     data_handlers:
-      - name: apply_custom_data_formatting_jinja_template
+      - name: apply_custom_jinja_template
         arguments:
           remove_columns: all
           batched: false
tests/artifacts/predefined_data_configs/duplicate_columns.yaml

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+dataprocessor:
+  type: default
+datasets:
+  - name: pre_tokenized_with_only_input_ids
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: duplicate_columns
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            old_column: "input_ids"
+            new_column: "labels"
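At preprocessing time, a config entry like the one above effectively asks the framework to run the named handler through `datasets.Dataset.map`. Roughly the following sketch; the library does this wiring internally, and `"FILE_PATH"` is a placeholder that the tests substitute with a real path.

```python
# Approximate equivalent of the duplicate_columns.yaml config above;
# the data preprocessor performs this wiring itself.
from datasets import load_dataset

from tuning.data.data_handlers import duplicate_columns

ds = load_dataset("json", data_files="FILE_PATH")["train"]
ds = ds.map(
    duplicate_columns,
    batched=False,
    fn_kwargs={"old_column": "input_ids", "new_column": "labels"},
)
```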

tests/artifacts/testdata/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -53,6 +53,10 @@
 TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
     JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
 )
+TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON = os.path.join(
+    JSON_DATA_DIR,
+    "twitter_complaints_tokenized_with_maykeye_tinyllama_v0_only_input_ids.json",
+)
 TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
     JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
 )
twitter_complaints_tokenized_with_maykeye_tinyllama_v0_only_input_ids.json (new test-data file under JSON_DATA_DIR)

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+[
+  {
+    "input_ids": [1, 16121, 9211, 31871, 1662, 31866, 31856, 7416, 17632, 369, 1398, 433, 322, 629, 712, 1784, 13, 13, 8458, 31922, 21597, 31871, 697, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 1662, 31892, 1260, 31825, 11273, 503, 31857, 632, 5284, 365, 329, 553, 1280, 31905, 960, 365, 6194, 289, 11025, 31844, 365, 473, 987, 12207, 4218, 389, 31822, 31853, 31854, 31886, 31852, 31852, 31854, 11300, 31847, 3873, 1507, 31843, 13, 13, 8458, 31922, 21597, 31871, 697, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 960, 312, 473, 31876, 31824, 685, 629, 31822, 31878, 4449, 5861, 287, 1662, 1299, 1574, 1590, 31833, 263, 1360, 1299, 1574, 289, 623, 31822, 31824, 16346, 312, 31876, 31836, 994, 277, 3560, 567, 31843, 672, 322, 260, 29458, 288, 629, 14881, 31843, 2628, 1423, 1662, 31858, 601, 1662, 31858, 601, 8378, 13, 13, 8458, 31922, 21597, 31871, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 1662, 7766, 1078, 8123, 17561, 308, 3456, 1833, 975, 10849, 291, 4372, 15379, 504, 10011, 2368, 1512, 31822, 31855, 31852, 31852, 1243, 31843, 3007, 322, 433, 31843, 13, 13, 8458, 31922, 21597, 31871, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 12371, 2208, 26657, 31844, 560, 14138, 31843, 21994, 1257, 24870, 496, 31829, 8198, 19057, 13, 13, 8458, 31922, 21597, 31871, 697, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 1662, 31836, 651, 307, 395, 13094, 672, 1467, 701, 333, 515, 31844, 504, 1097, 2266, 282, 305, 781, 31902, 21626, 31822, 31824, 5540, 397, 560, 5253, 662, 365, 31876, 263, 4985, 31854, 8903, 16801, 291, 612, 31925, 2011, 1129, 31824, 31843, 1358, 31873, 19919, 31824, 31865, 31829, 469, 2131, 31874, 13, 13, 8458, 31922, 21597, 31871, 697, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 1662, 31900, 307, 31837, 473, 382, 685, 266, 3195, 17532, 329, 260, 1173, 9363, 352, 1671, 1881, 646, 619, 31822, 31882, 5556, 504, 2091, 31822, 31882, 31843, 31855, 31861, 405, 499, 382, 863, 260, 31822, 31878, 4449, 2540, 2042, 31902, 13, 13, 8458, 31922, 21597, 31871, 697, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 1662, 14390, 16373, 337, 312, 435, 697, 1579, 291, 266, 3925, 322, 1434, 291, 3877, 31843, 1456, 365, 499, 1419, 562, 433, 31902, 13, 13, 8458, 31922, 21597, 31871, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 7265, 7550, 389, 1662, 31856, 2226, 11596, 27771, 898, 31843, 3259, 647, 312, 498, 288, 635, 31844, 518, 3822, 397, 2168, 28910, 31873, 13627, 4107, 1708, 31843, 312, 31876, 608, 1090, 629, 10279, 289, 1662, 29966, 31831, 5605, 13, 13, 8458, 31922, 21597, 31871, 9566]
+  },
+  {
+    "input_ids": [1, 16121, 9211, 31871, 1662, 31884, 1450, 7064, 31847, 6538, 30894, 4472, 289, 362, 828, 31843, 864, 685, 541, 9932, 843, 584, 18694, 31986, 13, 13, 8458, 31922, 21597, 31871, 697, 9566]
+  }
+]
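Concretely, running `duplicate_columns` over this file with `old_column="input_ids"` and `new_column="labels"` should leave each record with an identical `labels` list alongside `input_ids`. A worked illustration against a truncated copy of the first record, using the sketch above (not actual test output):

```python
# Illustration only; the record is truncated for brevity.
record = {"input_ids": [1, 16121, 9211, 31871, 1662, 31866, 31856, 7416]}
out = duplicate_columns(dict(record), old_column="input_ids", new_column="labels")
assert out["labels"] == out["input_ids"]  # both columns present and equal
```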

tests/data/test_data_handlers.py

Lines changed: 57 additions & 4 deletions
@@ -21,13 +21,19 @@
 import pytest

 # First Party
-from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL
+from tests.artifacts.testdata import (
+    MODEL_NAME,
+    TWITTER_COMPLAINTS_DATA_JSONL,
+    TWITTER_COMPLAINTS_TOKENIZED_JSON,
+    TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON,
+)

 # Local
 from tuning.data.data_handlers import (
-    apply_custom_data_formatting_jinja_template,
     apply_custom_data_formatting_template,
+    apply_custom_jinja_template,
     combine_sequence,
+    duplicate_columns,
 )

@@ -66,7 +72,7 @@ def test_apply_custom_formatting_jinja_template():
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     formatted_dataset_field = "formatted_data_field"
     formatted_dataset = json_dataset.map(
-        apply_custom_data_formatting_jinja_template,
+        apply_custom_jinja_template,
         fn_kwargs={
             "tokenizer": tokenizer,
             "dataset_text_field": formatted_dataset_field,

@@ -121,7 +127,7 @@ def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(temp
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     with pytest.raises((KeyError, ValueError)):
         json_dataset.map(
-            apply_custom_data_formatting_jinja_template,
+            apply_custom_jinja_template,
             fn_kwargs={
                 "tokenizer": tokenizer,
                 "dataset_text_field": formatted_dataset_field,

@@ -162,3 +168,50 @@ def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
     expected_res += tokenizer.eos_token
     assert isinstance(comb_seq, str)
     assert comb_seq == expected_res
+
+
+@pytest.mark.parametrize(
+    "dataset, old, new",
+    [
+        (TWITTER_COMPLAINTS_DATA_JSONL, "input_ids", "labels"),
+        (TWITTER_COMPLAINTS_TOKENIZED_JSON, "input_ids", "labels"),
+        (TWITTER_COMPLAINTS_DATA_JSONL, None, None),
+        (TWITTER_COMPLAINTS_DATA_JSONL, "input_ids", None),
+    ],
+)
+def test_duplicate_columns_throws_error_on_wrong_args(dataset, old, new):
+    """Ensure that duplicate_columns data handler throws error if column names are wrong."""
+    d = datasets.load_dataset("json", data_files=dataset)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    with pytest.raises(ValueError):
+        d.map(
+            duplicate_columns,
+            fn_kwargs={
+                "tokenizer": tokenizer,
+                "old_column": old,
+                "new_column": new,
+            },
+        )
+
+
+def test_duplicate_columns_copies_columns():
+    """Ensure that duplicate_columns data handler copies and maintains both columns."""
+    old = "input_ids"
+    new = "labels"
+    d = datasets.load_dataset(
+        "json", data_files=TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    updated_dataset = d.map(
+        duplicate_columns,
+        fn_kwargs={
+            "tokenizer": tokenizer,
+            "old_column": old,
+            "new_column": new,
+        },
+    )
+
+    first_element = updated_dataset["train"][0]
+    assert new in first_element
+    assert old in first_element
+    assert first_element[new] == first_element[old]
File renamed without changes.

tests/test_sft_trainer.py

Lines changed: 50 additions & 2 deletions
@@ -36,6 +36,7 @@
 from build.utils import serialize_args
 from scripts.run_inference import TunedCausalLM
 from tests.artifacts.predefined_data_configs import (
+    DATA_CONFIG_DUPLICATE_COLUMNS,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )

@@ -58,6 +59,7 @@
     TWITTER_COMPLAINTS_TOKENIZED_ARROW,
     TWITTER_COMPLAINTS_TOKENIZED_JSON,
     TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+    TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON,
     TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
 )

@@ -71,7 +73,7 @@
     DataPreProcessorConfig,
     DataSetConfig,
 )
-from tuning.data.data_handlers import apply_dataset_formatting
+from tuning.data.data_handlers import add_tokenizer_eos_token

 MODEL_ARGS = configs.ModelArguments(
     model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"

@@ -880,6 +882,52 @@ def test_run_causallm_ft_and_inference_with_multiple_dataset(
     assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


+def test_run_training_with_pretokenised_dataset_containing_input_ids():
+    """Ensure that we can train on a pretokenised dataset containing just input_ids by
+    choosing the duplicate_columns data handler via the data config."""
+    with tempfile.TemporaryDirectory() as tempdir:
+
+        data_args = copy.deepcopy(DATA_ARGS)
+
+        # set training_data_path and response_template to none
+        data_args.response_template = None
+        data_args.training_data_path = None
+
+        dataconfigfile = DATA_CONFIG_DUPLICATE_COLUMNS
+        datapath = TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON
+
+        # add data_paths in data_config file
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, suffix=".yaml"
+        ) as temp_yaml_file:
+            with open(dataconfigfile, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for _, d in enumerate(datasets):
+                d["data_paths"] = [datapath]
+            yaml.dump(data, temp_yaml_file)
+            data_args.data_config_path = temp_yaml_file.name
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, data_args, train_args)
+
+        # validate full ft configs
+        _validate_training(tempdir)
+        checkpoint_path = _get_checkpoint_path(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+        )
+        assert len(output_inference) > 0
+        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+
+
 @pytest.mark.parametrize(
     "dataset_path",
     [CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN],

@@ -1469,7 +1517,7 @@ def test_run_by_passing_additional_data_handlers():
     TEST_HANDLER = "my_test_handler"

     def test_handler(element, tokenizer, **kwargs):
-        return apply_dataset_formatting(element, tokenizer, "custom_formatted_field")
+        return add_tokenizer_eos_token(element, tokenizer, "custom_formatted_field")

     # This data config calls for data handler to be applied to dataset
     preprocessor_config = DataPreProcessorConfig()
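The renamed handler used in the final hunk is small. Here is a plausible sketch, consistent with its call site `add_tokenizer_eos_token(element, tokenizer, "custom_formatted_field")` and the docs entry above ("appends the tokenizer's EOS token to a specified dataset field"); the actual function in `tuning/data/data_handlers.py` may differ.

```python
# Sketch only: signature inferred from the call site in test_sft_trainer.py.
def add_tokenizer_eos_token(element, tokenizer, dataset_text_field, **kwargs):
    # Append the tokenizer's EOS token to the requested text field.
    return {dataset_text_field: element[dataset_text_field] + tokenizer.eos_token}
```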
