
Commit fbd61d0

Merge branch 'main' into save-model-dir-moe-checkpoint
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
2 parents bc9cbba + 1c9f773 commit fbd61d0


3 files changed: +49 −39 lines


tests/test_sft_trainer.py

Lines changed: 9 additions & 6 deletions
@@ -761,17 +761,18 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
 
 
 @pytest.mark.parametrize(
-    "dataset_path",
+    "dataset_path, packing",
     [
-        TWITTER_COMPLAINTS_TOKENIZED_JSONL,
-        TWITTER_COMPLAINTS_TOKENIZED_JSON,
-        TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
-        TWITTER_COMPLAINTS_TOKENIZED_ARROW,
+        (TWITTER_COMPLAINTS_TOKENIZED_JSONL, True),
+        (TWITTER_COMPLAINTS_TOKENIZED_PARQUET, True),
+        (TWITTER_COMPLAINTS_TOKENIZED_JSON, False),
+        (TWITTER_COMPLAINTS_TOKENIZED_ARROW, False),
     ],
 )
-def test_run_causallm_ft_pretokenized(dataset_path):
+def test_run_causallm_ft_pretokenized(dataset_path, packing):
     """Check if we can bootstrap and finetune causallm models using pretokenized data"""
     with tempfile.TemporaryDirectory() as tempdir:
+
         data_formatting_args = copy.deepcopy(DATA_ARGS)
 
         # below args not needed for pretokenized data
@@ -784,6 +785,8 @@ def test_run_causallm_ft_pretokenized(dataset_path):
 
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.output_dir = tempdir
+        train_args.packing = packing
+        train_args.max_seq_length = 256
 
         sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)
 
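The tuple-style parametrization above relies on pytest unpacking each (dataset_path, packing) pair into the test's two arguments, so half of the pretokenized formats now run with packing enabled. A minimal standalone sketch of the same pattern, with hypothetical paths in place of the repo's TWITTER_COMPLAINTS_TOKENIZED_* constants:

import pytest

# Hypothetical stand-ins for the repo's TWITTER_COMPLAINTS_TOKENIZED_* paths.
CASES = [
    ("data/twitter_complaints.jsonl", True),   # run this format with packing enabled
    ("data/twitter_complaints.json", False),   # run this format with packing disabled
]


@pytest.mark.parametrize("dataset_path, packing", CASES)
def test_parametrize_unpacking(dataset_path, packing):
    # pytest expands CASES into one test invocation per tuple, binding the
    # first element to dataset_path and the second to packing.
    assert isinstance(dataset_path, str)
    assert isinstance(packing, bool)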

tuning/data/data_preprocessing_utils.py

Lines changed: 39 additions & 28 deletions
@@ -55,44 +55,55 @@ def get_data_collator(
             Callable collator to be leveraged by the trainer.
     """
 
+    if packing:
+        if is_traindata_tokenized:
+            # packing with tokenized dataset requires seq2seq collator.
+            return DataCollatorForSeq2Seq(
+                tokenizer=tokenizer, padding=False, max_length=max_seq_length
+            )
+
+        # packing for non tokenized dataset doesn't require a collator with SFTrainer.
+        return None
+
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     if response_template and instruction_template:
+        # Pass both instruction and response template for chat style training.
         return DataCollatorForCompletionOnlyLM(
             response_template=response_template,
             instruction_template=instruction_template,
             tokenizer=tokenizer,
             ignore_index=configs.IGNORE_INDEX,
         )
 
-    if not packing:
-        # TODO: near term - how response template ids are parsed out needs to be cleaned.
-        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
-        # otherwise template is not found. We will create issue to clean this out after we discuss
-        # data formats and collators we will support.
-        if response_template:
-            response_template_ids = tokenizer.encode(
-                response_template, add_special_tokens=False
-            )[2:]
-            return DataCollatorForCompletionOnlyLM(
-                response_template=response_template_ids,
-                tokenizer=tokenizer,
-                ignore_index=configs.IGNORE_INDEX,
-            )
+    if response_template:
+        response_template_ids = tokenizer.encode(
+            response_template, add_special_tokens=False
+        )[2:]
+        return DataCollatorForCompletionOnlyLM(
+            response_template=response_template_ids,
+            tokenizer=tokenizer,
+            ignore_index=configs.IGNORE_INDEX,
+        )
 
-        if is_padding_free:
-            # when packing is false but padding_free is used and
-            # no response template is used then its a pretrained scenario.
-            # Current plugin in fms-acceleration is compatible with
-            # `DataCollatorForSeq2Seq` collator hence we use this.
-            return DataCollatorForSeq2Seq(
-                tokenizer=tokenizer, padding=False, max_length=max_seq_length
-            )
+    if is_padding_free:
+        # when packing is false but padding_free is used and
+        # no response template is used then its a pretrained scenario.
+        # Current plugin in fms-acceleration is compatible with
+        # `DataCollatorForSeq2Seq` collator hence we use this.
+        return DataCollatorForSeq2Seq(
+            tokenizer=tokenizer, padding=False, max_length=max_seq_length
+        )
 
+    if is_traindata_tokenized:
         # Note that this automatically pads labels with -100
         # TODO check if this is sufficient for preprocessed
-        if is_traindata_tokenized:
-            return DataCollatorForSeq2Seq(
-                tokenizer=tokenizer, padding=True, max_length=max_seq_length
-            )
-        raise ValueError(
-            "Could not pick a data collator. Please refer to supported data formats"
+        return DataCollatorForSeq2Seq(
+            tokenizer=tokenizer, padding=True, max_length=max_seq_length
         )
+
+    raise ValueError(
+        "Could not pick a data collator. Please refer to supported data formats"
+    )
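After this change, get_data_collator short-circuits on packing before any response/instruction template handling: packed tokenized data gets a DataCollatorForSeq2Seq with padding disabled, and packed non-tokenized data gets no collator at all, since SFTTrainer packs it itself. A rough sketch of exercising those two new branches, assuming fms-hf-tuning is installed and that the names used in the hunk (packing, response_template, tokenizer, is_traindata_tokenized, max_seq_length, instruction_template) are keyword parameters of the function, since its signature sits outside this hunk:

from transformers import AutoTokenizer

from tuning.data.data_preprocessing_utils import get_data_collator

# Any tokenizer works for illustration; gpt2 is just a small public choice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Packing + pretokenized data: the new first branch should hand back a
# DataCollatorForSeq2Seq configured with padding=False.
collator = get_data_collator(
    packing=True,
    response_template=None,
    tokenizer=tokenizer,
    is_traindata_tokenized=True,
    max_seq_length=256,
    instruction_template=None,
)
print(type(collator).__name__)  # expected: DataCollatorForSeq2Seq

# Packing + non-tokenized data: no collator is returned.
collator = get_data_collator(
    packing=True,
    response_template=None,
    tokenizer=tokenizer,
    is_traindata_tokenized=False,
    max_seq_length=256,
    instruction_template=None,
)
print(collator)  # expected: None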

tuning/data/setup_dataprocessor.py

Lines changed: 1 addition & 5 deletions
@@ -396,10 +396,6 @@ def process_dataargs(
         is_padding_free=is_padding_free,
     )
 
-    dataset_kwargs = {}
-    if is_tokenized_dataset:
-        dataset_kwargs["skip_prepare_dataset"] = True
-
     if isinstance(train_dataset, IterableDataset):
         train_args.accelerator_config = {"split_batches": True}
         logger.info(
@@ -415,5 +411,5 @@ def process_dataargs(
         dataset_text_field,
         data_collator,
         max_seq_length,
-        dataset_kwargs,
+        None,
     )
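With the collator now chosen entirely inside get_data_collator, process_dataargs no longer builds dataset_kwargs, and the final slot of its return tuple is simply None. For reference, the deleted block only ever produced one of two dictionaries; a standalone restatement of that removed logic (the helper name here is hypothetical, not from the repo):

def build_dataset_kwargs(is_tokenized_dataset: bool) -> dict:
    # Mirrors the removed lines: only a pretokenized dataset asked the
    # trainer to skip its own dataset preparation step.
    dataset_kwargs = {}
    if is_tokenized_dataset:
        dataset_kwargs["skip_prepare_dataset"] = True
    return dataset_kwargs


print(build_dataset_kwargs(True))   # {'skip_prepare_dataset': True}
print(build_dataset_kwargs(False))  # {}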
