15 changes: 9 additions & 6 deletions tests/test_sft_trainer.py
@@ -761,17 +761,18 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():


@pytest.mark.parametrize(
"dataset_path",
"dataset_path, packing",
[
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
(TWITTER_COMPLAINTS_TOKENIZED_JSONL, True),
(TWITTER_COMPLAINTS_TOKENIZED_PARQUET, True),
(TWITTER_COMPLAINTS_TOKENIZED_JSON, False),
(TWITTER_COMPLAINTS_TOKENIZED_ARROW, False),
],
)
def test_run_causallm_ft_pretokenized(dataset_path):
def test_run_causallm_ft_pretokenized(dataset_path, packing):
"""Check if we can bootstrap and finetune causallm models using pretokenized data"""
with tempfile.TemporaryDirectory() as tempdir:

data_formatting_args = copy.deepcopy(DATA_ARGS)

# below args not needed for pretokenized data
@@ -784,6 +785,8 @@ def test_run_causallm_ft_pretokenized(dataset_path):

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.packing = packing
train_args.max_seq_length = 256

sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)

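The test now sets both packing and max_seq_length because packing a pre-tokenized dataset means concatenating the already-tokenized examples and re-chunking them into fixed-length blocks. Below is a conceptual sketch of that operation only; it is not the repo's or TRL's implementation, and pack_examples is a hypothetical helper name.

```python
# Conceptual sketch only: illustrates why packing needs max_seq_length.
# pack_examples is a hypothetical helper, not part of fms-hf-tuning or TRL.
from typing import Dict, List


def pack_examples(
    examples: List[List[int]], max_seq_length: int = 256
) -> List[Dict[str, List[int]]]:
    """Concatenate token-id lists and split the stream into fixed-length blocks."""
    stream: List[int] = [tok for example in examples for tok in example]
    blocks = [
        stream[i : i + max_seq_length]
        for i in range(0, len(stream) - max_seq_length + 1, max_seq_length)
    ]
    # For causal-LM fine-tuning, labels simply mirror input_ids in each block.
    return [{"input_ids": block, "labels": list(block)} for block in blocks]


if __name__ == "__main__":
    toy = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
    print(pack_examples(toy, max_seq_length=4))
    # [{'input_ids': [1, 2, 3, 4], 'labels': [1, 2, 3, 4]},
    #  {'input_ids': [5, 6, 7, 8], 'labels': [5, 6, 7, 8]}]
```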
67 changes: 39 additions & 28 deletions tuning/data/data_preprocessing_utils.py
@@ -55,44 +55,55 @@ def get_data_collator(
Callable collator to be leveraged by the trainer.
"""

if packing:
if is_traindata_tokenized:
# packing with tokenized dataset requires seq2seq collator.
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=False, max_length=max_seq_length
)

# packing for non tokenized dataset doesn't require a collator with SFTrainer.
return None

# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
# otherwise template is not found. We will create issue to clean this out after we discuss
# data formats and collators we will support.
if response_template and instruction_template:
# Pass both instruction and response template for chat style training.
return DataCollatorForCompletionOnlyLM(
response_template=response_template,
instruction_template=instruction_template,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)

if not packing:
# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
# otherwise template is not found. We will create issue to clean this out after we discuss
# data formats and collators we will support.
if response_template:
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:]
return DataCollatorForCompletionOnlyLM(
response_template=response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)
if response_template:
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:]
return DataCollatorForCompletionOnlyLM(
response_template=response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)

if is_padding_free:
# when packing is false but padding_free is used and
# no response template is used then its a pretrained scenario.
# Current plugin in fms-acceleration is compatible with
# `DataCollatorForSeq2Seq` collator hence we use this.
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=False, max_length=max_seq_length
)
if is_padding_free:
# when packing is false but padding_free is used and
# no response template is used then its a pretrained scenario.
# Current plugin in fms-acceleration is compatible with
# `DataCollatorForSeq2Seq` collator hence we use this.
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=False, max_length=max_seq_length
)

if is_traindata_tokenized:
# Note that this automatically pads labels with -100
# TODO check if this is sufficient for preprocessed
if is_traindata_tokenized:
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=True, max_length=max_seq_length
)
raise ValueError(
"Could not pick a data collator. Please refer to supported data formats"
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=True, max_length=max_seq_length
)

raise ValueError(
"Could not pick a data collator. Please refer to supported data formats"
)
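To make the restructured branches concrete, here is a hedged sketch of the collator each combination is expected to produce after this change. It assumes get_data_collator accepts keyword arguments matching the names used in its body (the full signature is not shown in this hunk) and uses the GPT-2 tokenizer purely as a stand-in.

```python
# Hedged sketch: expected collator choices after this change. Assumes the
# keyword parameters match the names used in the function body above; the
# real signature may differ in order or defaults.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from trl import DataCollatorForCompletionOnlyLM

from tuning.data.data_preprocessing_utils import get_data_collator

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
common = dict(tokenizer=tokenizer, max_seq_length=256, instruction_template=None)

# Packing a pre-tokenized dataset -> seq2seq collator with padding disabled.
packed_tokenized = get_data_collator(
    packing=True, response_template=None,
    is_traindata_tokenized=True, is_padding_free=False, **common
)
assert isinstance(packed_tokenized, DataCollatorForSeq2Seq)

# Packing a non-tokenized dataset -> no collator; SFTTrainer handles packing.
packed_raw = get_data_collator(
    packing=True, response_template=None,
    is_traindata_tokenized=False, is_padding_free=False, **common
)
assert packed_raw is None

# No packing, with a response template -> completion-only masking collator.
completion_only = get_data_collator(
    packing=False, response_template="\n### Response:",
    is_traindata_tokenized=False, is_padding_free=False, **common
)
assert isinstance(completion_only, DataCollatorForCompletionOnlyLM)
```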
6 changes: 1 addition & 5 deletions tuning/data/setup_dataprocessor.py
@@ -396,10 +396,6 @@ def process_dataargs(
is_padding_free=is_padding_free,
)

dataset_kwargs = {}
if is_tokenized_dataset:
dataset_kwargs["skip_prepare_dataset"] = True

Comment on lines -399 to -402
Collaborator: Any reason we don't want to skip anymore?

Collaborator (Author): Yes, added an explanation here: #468 (comment). We need to remove this check so that prepare_dataset is called, enabling packing for the pretokenized dataset.

Collaborator: Just checked, it looks like TRL checks for a tokenized dataset and skips prepare.

if isinstance(train_dataset, IterableDataset):
train_args.accelerator_config = {"split_batches": True}
logger.info(
@@ -415,5 +411,5 @@
dataset_text_field,
data_collator,
max_seq_length,
dataset_kwargs,
None,
)
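For readers following the review thread above: the deleted block forwarded dataset_kwargs={"skip_prepare_dataset": True} for tokenized datasets, which tells TRL's SFTTrainer to skip its dataset-preparation step entirely, and with it any packing. Below is a hedged sketch of the difference using trl.SFTConfig field names; whether an already-tokenized dataset actually gets packed also depends on the installed TRL version's own internal checks.

```python
# Hedged illustration of the TRL-side effect discussed in the thread above.
# Field names follow trl.SFTConfig; exact behavior depends on the TRL version.
from trl import SFTConfig

# Before this change: preparation (and therefore packing) was skipped for
# pre-tokenized datasets, so packing=True had no effect on them.
before = SFTConfig(
    output_dir="out",
    packing=True,
    max_seq_length=256,
    dataset_kwargs={"skip_prepare_dataset": True},
)

# After this change: no skip flag is passed, so SFTTrainer may run its
# preparation step and pack the tokenized rows into max_seq_length blocks.
after = SFTConfig(
    output_dir="out",
    packing=True,
    max_seq_length=256,
    dataset_kwargs=None,
)
```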