Skip to content

Commit 73b81c1

Browse files
committed
Add seq2seq collator for packing pretokenized data
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent 87d987a commit 73b81c1

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

tests/test_sft_trainer.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -756,17 +756,18 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
756756

757757

758758
@pytest.mark.parametrize(
759-
"dataset_path",
759+
"dataset_path, packing",
760760
[
761-
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
762-
TWITTER_COMPLAINTS_TOKENIZED_JSON,
763-
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
764-
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
761+
(TWITTER_COMPLAINTS_TOKENIZED_JSON, False),
762+
(TWITTER_COMPLAINTS_TOKENIZED_JSONL, True),
763+
(TWITTER_COMPLAINTS_TOKENIZED_PARQUET, True),
764+
(TWITTER_COMPLAINTS_TOKENIZED_ARROW, False),
765765
],
766766
)
767-
def test_run_causallm_ft_pretokenized(dataset_path):
767+
def test_run_causallm_ft_pretokenized(dataset_path, packing):
768768
"""Check if we can bootstrap and finetune causallm models using pretokenized data"""
769769
with tempfile.TemporaryDirectory() as tempdir:
770+
770771
data_formatting_args = copy.deepcopy(DATA_ARGS)
771772

772773
# below args not needed for pretokenized data
@@ -779,6 +780,8 @@ def test_run_causallm_ft_pretokenized(dataset_path):
779780

780781
train_args = copy.deepcopy(TRAIN_ARGS)
781782
train_args.output_dir = tempdir
783+
train_args.packing = packing
784+
train_args.max_seq_length = 256
782785

783786
sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)
784787

tuning/data/data_preprocessing_utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,3 +96,8 @@ def get_data_collator(
9696
raise ValueError(
9797
"Could not pick a data collator. Please refer to supported data formats"
9898
)
99+
100+
if is_traindata_tokenized:
101+
return DataCollatorForSeq2Seq(
102+
tokenizer=tokenizer, padding=False, max_length=max_seq_length
103+
)

0 commit comments

Comments (0)