
Commit 53a9d18

dushyantbehl, HarikrishnanBalagopal, and kmehant authored
feat: allow for padding free plugin to be used without response template (#430)
* fix: allow for padding free + pretraining
  Signed-off-by: Harikrishnan Balagopal <[email protected]>

* add data collator for padding free plugin scenario to be used for extended pretraining
  Signed-off-by: Dushyant Behl <[email protected]>

* fix: update value error
  Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: delete images only when exists
  Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Harikrishnan Balagopal <[email protected]>
Signed-off-by: Dushyant Behl <[email protected]>
Signed-off-by: Mehant Kammakomati <[email protected]>
Co-authored-by: Harikrishnan Balagopal <[email protected]>
Co-authored-by: Mehant Kammakomati <[email protected]>
1 parent 8851227 commit 53a9d18
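Background (editorial note, not part of the commit message): "padding free" refers to the fms-acceleration plugin that concatenates variable-length examples into one unpadded batch instead of padding each example, so no compute is spent on pad tokens. A toy illustration of the two batch layouts (values are made up; the real plugin tracks sequence boundaries internally):

```python
# Toy illustration (not from the repo): padded vs. padding-free batch layout.
padded_batch = {
    "input_ids": [
        [5, 6, 7, 0, 0],  # 0 = pad token, wasted compute
        [8, 9, 0, 0, 0],
    ],
}
padding_free_batch = {
    "input_ids":    [[5, 6, 7, 8, 9]],   # both examples concatenated, no pads
    "position_ids": [[0, 1, 2, 0, 1]],   # a position reset marks each boundary
}
```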

File tree: 6 files changed, +99 -23 lines changed

.github/workflows/image.yaml

Lines changed: 1 addition & 2 deletions
@@ -15,9 +15,8 @@ jobs:
           sudo swapoff -a
           sudo rm -f /swapfile
           sudo apt clean
-          docker rmi $(docker image ls -aq)
+          if [ "$(docker image ls -q)" ]; then docker rmi $(docker image ls -aq); fi
           df -h
       - name: Build image
         run: |
           docker build -t fms-hf-tuning:dev . -f build/Dockerfile
-

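For context, the old step ran `docker rmi` unconditionally; on a fresh runner with no images, `$(docker image ls -aq)` expands to nothing and `docker rmi` exits non-zero. The guard invokes it only when at least one image ID is listed. A rough Python re-expression of the guarded cleanup, for illustration only (the workflow itself stays in shell):

```python
# Illustration only: the shell guard above, re-expressed with subprocess.
import subprocess

def remove_all_docker_images() -> None:
    # `docker image ls -aq` prints one image ID per line; empty output
    # means there is nothing to remove, so `docker rmi` is never invoked
    # with an empty argument list.
    image_ids = subprocess.run(
        ["docker", "image", "ls", "-aq"],
        capture_output=True, text=True, check=True,
    ).stdout.split()
    if image_ids:
        subprocess.run(["docker", "rmi", *image_ids], check=True)
```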
tests/data/test_data_preprocessing_utils.py

Lines changed: 49 additions & 13 deletions
@@ -489,7 +489,7 @@ def test_is_pretokenized_data(data, result):
 
 @pytest.mark.parametrize(
     "packing, response_template, formatted_train_dataset,\
-    max_seq_length, instruction_template, expected_collator",
+    max_seq_length, instruction_template, is_padding_free, expected_collator",
     [
         (
             False,
@@ -501,6 +501,7 @@ def test_is_pretokenized_data(data, result):
             ),
             1024,
             None,
+            False,
             DataCollatorForCompletionOnlyLM,
         ),
         (
@@ -517,6 +518,7 @@ def test_is_pretokenized_data(data, result):
             ),
             1024,
             None,
+            False,
             DataCollatorForSeq2Seq,
         ),
         (
@@ -529,6 +531,7 @@ def test_is_pretokenized_data(data, result):
             ),
             1024,
             "\n### Text:",
+            False,
             DataCollatorForCompletionOnlyLM,
         ),
         (
@@ -545,6 +548,20 @@ def test_is_pretokenized_data(data, result):
             ),
             1024,
             "\n### Text:",
+            False,
+            DataCollatorForSeq2Seq,
+        ),
+        (
+            False,
+            None,
+            datasets.load_dataset(
+                "json",
+                data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                split="train",
+            ),
+            1024,
+            None,
+            True,
             DataCollatorForSeq2Seq,
         ),
     ],
@@ -555,6 +572,7 @@ def test_get_data_collator(
     formatted_train_dataset,
     max_seq_length,
     instruction_template,
+    is_padding_free,
     expected_collator,
 ):
     """Ensure that the correct collator type is fetched based on the data args"""
@@ -565,6 +583,7 @@ def test_get_data_collator(
         is_pretokenized_dataset(formatted_train_dataset),
         max_seq_length,
         instruction_template,
+        is_padding_free,
     )
     assert isinstance(collator, expected_collator)
 
@@ -1044,7 +1063,7 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
 
 
 @pytest.mark.parametrize(
-    "data_args",
+    "data_args, is_padding_free",
     [
         # single sequence JSON and response template
         (
@@ -1053,7 +1072,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
                 dataset_text_field="output",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # single sequence JSONL and response template
         (
@@ -1062,7 +1082,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_JSONL,
                 dataset_text_field="output",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # single sequence PARQUET and response template
         (
@@ -1071,7 +1092,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
                 dataset_text_field="output",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # data formatter template with input/output JSON
         (
@@ -1080,7 +1102,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
                 data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # data formatter template with input/output JSONL
        (
@@ -1089,7 +1112,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                 data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # data formatter template with input/output PARQUET
         (
@@ -1098,32 +1122,44 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                 data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # input/output JSON with masking on input
         (
             configs.DataArguments(
                 training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
-            )
+            ),
+            False,
         ),
         # input/output JSONL with masking on input
         (
             configs.DataArguments(
                 training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
-            )
+            ),
+            False,
         ),
         # input/output PARQUET with masking on input
         (
             configs.DataArguments(
                 training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
-            )
+            ),
+            False,
+        ),
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
+                validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
+                dataset_text_field="output",
+            ),
+            True,
         ),
     ],
 )
-def test_process_dataargs(data_args):
+def test_process_dataargs(data_args, is_padding_free):
     """Ensure that the train/eval data are properly formatted based on the data args / text field"""
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     TRAIN_ARGS = configs.TrainingArguments(
@@ -1132,7 +1168,7 @@ def test_process_dataargs(data_args, is_padding_free):
         output_dir="tmp",  # Not needed but positional
     )
     (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs(
-        data_args, tokenizer, TRAIN_ARGS
+        data_args, tokenizer, TRAIN_ARGS, is_padding_free=is_padding_free
     )
     assert isinstance(train_set, Dataset)
     assert isinstance(eval_set, Dataset)

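In short, the new parametrized case asserts that with packing off, no response template, and the padding-free flag set, the Seq2Seq collator is selected. A condensed sketch of just that case (keyword names follow the diff; the full positional signature of `get_data_collator` is only partially visible here, so treat the keywords as assumptions):

```python
# Condensed sketch of the new test case (assumes this module's existing
# imports: AutoTokenizer, get_data_collator, DataCollatorForSeq2Seq, MODEL_NAME).
def test_padding_free_selects_seq2seq_collator():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    collator = get_data_collator(
        packing=False,
        response_template=None,
        tokenizer=tokenizer,
        is_traindata_tokenized=False,
        max_seq_length=1024,
        instruction_template=None,
        is_padding_free=True,  # padding-free plugin active
    )
    # No response template + no packing + padding-free => pretraining path.
    assert isinstance(collator, DataCollatorForSeq2Seq)
```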
tuning/config/acceleration_configs/attention_and_distributed_packing.py

Lines changed: 4 additions & 0 deletions
@@ -47,3 +47,7 @@ class AttentionAndDistributedPackingConfig:
     def __post_init__(self):
         # ensure nested dataclasses initialized
         ensure_nested_dataclasses_initialized(self)
+
+    @property
+    def is_padding_free(self):
+        return self.padding_free is not None

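The new property gives callers a single boolean to test instead of probing the nested dataclass directly. A minimal sketch of the pattern (the `PaddingFree` stand-in below is hypothetical; only the property itself appears in the diff):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PaddingFree:  # hypothetical stand-in for the real nested dataclass
    method: str = "huggingface"

@dataclass
class AttentionAndDistributedPackingConfig:
    padding_free: Optional[PaddingFree] = None

    @property
    def is_padding_free(self) -> bool:
        # True whenever a padding_free sub-config was supplied.
        return self.padding_free is not None

assert AttentionAndDistributedPackingConfig(padding_free=PaddingFree()).is_padding_free
assert not AttentionAndDistributedPackingConfig().is_padding_free
```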
tuning/data/data_preprocessing_utils.py

Lines changed: 13 additions & 0 deletions
@@ -29,6 +29,7 @@ def get_data_collator(
     is_traindata_tokenized: bool,
     max_seq_length: int,
     instruction_template: Optional[str],
+    is_padding_free: bool = False,
 ) -> Callable:
     """Create and return the appropriate collator type based on the configuration for packing,
     response_template, and dataset_text_field.
@@ -46,6 +47,8 @@ def get_data_collator(
             Max sequence length expected
         instruction_template: str
             str representing the human response in a chat template
+        is_padding_free: bool
+            whether the padding-free plugin is in use
 
     Returns:
         Callable
@@ -74,6 +77,16 @@ def get_data_collator(
             tokenizer=tokenizer,
             ignore_index=configs.IGNORE_INDEX,
         )
+
+    if is_padding_free:
+        # When packing is false, padding_free is used, and no response
+        # template is provided, this is a pretraining scenario.
+        # The current plugin in fms-acceleration is compatible with the
+        # `DataCollatorForSeq2Seq` collator, hence we use it.
+        return DataCollatorForSeq2Seq(
+            tokenizer=tokenizer, padding=False, max_length=max_seq_length
+        )
+
     # Note that this automatically pads labels with -100
     # TODO check if this is sufficient for preprocessed
     if is_traindata_tokenized:

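The collator selection now has an extra outcome on this path. A simplified sketch of the control flow after this change (names taken from the diff; the unchanged branches are elided with `...`, so this is not the actual function body):

```python
# Simplified control flow of get_data_collator after this commit.
from transformers import DataCollatorForSeq2Seq

def get_data_collator_sketch(
    packing, response_template, tokenizer,
    is_traindata_tokenized, max_seq_length,
    instruction_template, is_padding_free=False,
):
    if response_template:
        ...  # DataCollatorForCompletionOnlyLM path (unchanged)
    if is_padding_free:
        # New branch: packing off, no response template, padding-free on
        # => pretraining, loss over all tokens, no padding in the collator.
        return DataCollatorForSeq2Seq(
            tokenizer=tokenizer, padding=False, max_length=max_seq_length
        )
    if is_traindata_tokenized:
        ...  # DataCollatorForSeq2Seq path for pre-tokenized data (unchanged)
```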
tuning/data/setup_dataprocessor.py

Lines changed: 21 additions & 7 deletions
@@ -107,15 +107,22 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):
 
 
 ### Data format 2
-def _get_dataset_formatting_handlers(data_args, packing):
+def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):
 
     if data_args.response_template is None:
         if packing is False:
-            raise ValueError(
-                "Since dataset_text_field or data_formatter_template \
-                is provided and packing is disabled, \
-                needs a corresponding response template for masking"
-            )
+            if is_padding_free:
+                logger.debug(
+                    "Assuming pretraining scenario (loss over all tokens) "
+                    + "because packing is false, "
+                    + "the padding_free plugin is used, and no response template was provided."
+                )
+            else:
+                raise ValueError(
+                    "Since response_template is not provided for masking, \
+                    either use packing or padding_free to enable the \
+                    pretraining scenario (loss over all tokens)."
+                )
 
     if data_args.response_template:
         # To use Response template, pass datasets with single sequence instances \
@@ -209,6 +216,7 @@ def _process_raw_data_args(
     packing: bool,
     max_seq_length: int,
     additional_data_handlers: Dict[str, Callable] = None,
+    is_padding_free: bool = False,
 ):
 
     # Create a data processor with default processor config
@@ -248,6 +256,7 @@ def _process_raw_data_args(
     tokenizer_kwargs = {}
     tokenizer_kwargs["max_length"] = max_seq_length
     tokenizer_kwargs["truncation"] = True
+    # Let's not pad in the tokenizer; we handle that in the collator.
     tokenizer_kwargs["padding"] = False
 
     handlers = None
@@ -266,7 +275,7 @@ def _process_raw_data_args(
     elif data_args.data_formatter_template or data_args.dataset_text_field:
         # Data Format 3: Single Sequence Dataset
         handlers, dataset_text_field = _get_dataset_formatting_handlers(
-            data_args, packing
+            data_args, packing, is_padding_free
         )
     else:
         # Default Data Format: Dataset with Input/Output Fields
@@ -300,6 +309,7 @@ def process_dataargs(
     tokenizer: AutoTokenizer,
     train_args: TrainingArguments,
     additional_data_handlers: Dict[str, Callable] = None,
+    is_padding_free: bool = False,
 ):
     """
     Args:
@@ -310,6 +320,8 @@ def process_dataargs(
             Used for packing and max_seq_length
         additional_data_handlers: A Dict of [str, callable] data handlers
             which need to be registered with the data preprocessor
+        is_padding_free: A bool indicating whether the padding-free plugin is enabled.
+            Defaults to False.
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
             tuple containing
@@ -345,6 +357,7 @@ def process_dataargs(
         train_args.packing,
         max_seq_length,
         additional_data_handlers,
+        is_padding_free,
     )
 
     # Note: This check should not be removed.
@@ -359,6 +372,7 @@ def process_dataargs(
         is_tokenized_dataset,
         max_seq_length,
         data_args.instruction_template,
+        is_padding_free=is_padding_free,
     )
 
     dataset_kwargs = {}

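The net effect of the reworked check: a single-sequence dataset without a response template is valid only when packing or padding-free supplies a pretraining-style loss; otherwise the ValueError still fires. An illustrative decision table (my own sketch mirroring `_get_dataset_formatting_handlers`, not the actual function):

```python
# Illustrative decision table for the response-template check.
def masking_mode(packing: bool, is_padding_free: bool, response_template):
    if response_template is not None:
        return "completion-only loss (masking via response template)"
    if packing or is_padding_free:
        return "pretraining (loss over all tokens)"
    raise ValueError(
        "Since response_template is not provided for masking, either use "
        "packing or padding_free to enable the pretraining scenario."
    )

assert masking_mode(False, True, None) == "pretraining (loss over all tokens)"
assert masking_mode(True, False, None) == "pretraining (loss over all tokens)"
# masking_mode(False, False, None) raises ValueError.
```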
tuning/sft_trainer.py

Lines changed: 11 additions & 1 deletion
@@ -306,6 +306,10 @@ def train(
     data_collator = None
     logger.info("Packing is set to %s ", train_args.packing)
 
+    is_padding_free = False
+    if attention_and_distributed_packing_config is not None:
+        is_padding_free = attention_and_distributed_packing_config.is_padding_free
+
     data_preprocessing_time = time.time()
     (
         formatted_train_dataset,
@@ -314,7 +318,13 @@ def train(
         data_collator,
         train_args.max_seq_length,
         dataset_kwargs,
-    ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
+    ) = process_dataargs(
+        data_args,
+        tokenizer,
+        train_args,
+        additional_data_handlers,
+        is_padding_free=is_padding_free,
+    )
     additional_metrics["data_preprocessing_time"] = (
         time.time() - data_preprocessing_time
     )

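Taken together, `train()` derives the flag once from the acceleration config and threads a plain bool through the data pipeline, so the data layer never has to import the acceleration config object itself. The caller-side pattern, condensed (the 6-tuple returned by `process_dataargs` is elided here):

```python
# Caller-side pattern established by this commit (condensed sketch).
is_padding_free = (
    attention_and_distributed_packing_config is not None
    and attention_and_distributed_packing_config.is_padding_free
)
outputs = process_dataargs(
    data_args,
    tokenizer,
    train_args,
    additional_data_handlers,
    is_padding_free=is_padding_free,
)
```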