diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index a69446231..037c0420f 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -761,17 +761,18 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
 
 
 @pytest.mark.parametrize(
-    "dataset_path",
+    "dataset_path, packing",
     [
-        TWITTER_COMPLAINTS_TOKENIZED_JSONL,
-        TWITTER_COMPLAINTS_TOKENIZED_JSON,
-        TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
-        TWITTER_COMPLAINTS_TOKENIZED_ARROW,
+        (TWITTER_COMPLAINTS_TOKENIZED_JSONL, True),
+        (TWITTER_COMPLAINTS_TOKENIZED_PARQUET, True),
+        (TWITTER_COMPLAINTS_TOKENIZED_JSON, False),
+        (TWITTER_COMPLAINTS_TOKENIZED_ARROW, False),
     ],
 )
-def test_run_causallm_ft_pretokenized(dataset_path):
+def test_run_causallm_ft_pretokenized(dataset_path, packing):
     """Check if we can bootstrap and finetune causallm models using pretokenized data"""
     with tempfile.TemporaryDirectory() as tempdir:
+
         data_formatting_args = copy.deepcopy(DATA_ARGS)
 
         # below args not needed for pretokenized data
@@ -784,6 +785,8 @@ def test_run_causallm_ft_pretokenized(dataset_path):
 
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.output_dir = tempdir
+        train_args.packing = packing
+        train_args.max_seq_length = 256
 
         sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)
 
diff --git a/tuning/data/data_preprocessing_utils.py b/tuning/data/data_preprocessing_utils.py
index b77fdba1d..a4abbaa70 100644
--- a/tuning/data/data_preprocessing_utils.py
+++ b/tuning/data/data_preprocessing_utils.py
@@ -55,7 +55,22 @@ def get_data_collator(
             Callable collator to be leveraged by the trainer.
     """
 
+    if packing:
+        if is_traindata_tokenized:
+            # packing with tokenized dataset requires seq2seq collator.
+            return DataCollatorForSeq2Seq(
+                tokenizer=tokenizer, padding=False, max_length=max_seq_length
+            )
+
+        # packing for non tokenized dataset doesn't require a collator with SFTTrainer.
+        return None
+
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     if response_template and instruction_template:
+        # Pass both instruction and response template for chat style training.
         return DataCollatorForCompletionOnlyLM(
             response_template=response_template,
             instruction_template=instruction_template,
@@ -63,36 +78,32 @@ def get_data_collator(
             ignore_index=configs.IGNORE_INDEX,
         )
 
-    if not packing:
-        # TODO: near term - how response template ids are parsed out needs to be cleaned.
-        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
-        # otherwise template is not found. We will create issue to clean this out after we discuss
-        # data formats and collators we will support.
-        if response_template:
-            response_template_ids = tokenizer.encode(
-                response_template, add_special_tokens=False
-            )[2:]
-            return DataCollatorForCompletionOnlyLM(
-                response_template=response_template_ids,
-                tokenizer=tokenizer,
-                ignore_index=configs.IGNORE_INDEX,
-            )
+    if response_template:
+        response_template_ids = tokenizer.encode(
+            response_template, add_special_tokens=False
+        )[2:]
+        return DataCollatorForCompletionOnlyLM(
+            response_template=response_template_ids,
+            tokenizer=tokenizer,
+            ignore_index=configs.IGNORE_INDEX,
+        )
 
-        if is_padding_free:
-            # when packing is false but padding_free is used and
-            # no response template is used then its a pretrained scenario.
-            # Current plugin in fms-acceleration is compatible with
-            # `DataCollatorForSeq2Seq` collator hence we use this.
-            return DataCollatorForSeq2Seq(
-                tokenizer=tokenizer, padding=False, max_length=max_seq_length
-            )
+    if is_padding_free:
+        # when packing is false but padding_free is used and
+        # no response template is used then its a pretrained scenario.
+        # Current plugin in fms-acceleration is compatible with
+        # `DataCollatorForSeq2Seq` collator hence we use this.
+        return DataCollatorForSeq2Seq(
+            tokenizer=tokenizer, padding=False, max_length=max_seq_length
+        )
 
+    if is_traindata_tokenized:
         # Note that this automatically pads labels with -100
         # TODO check if this is sufficient for preprocessed
-        if is_traindata_tokenized:
-            return DataCollatorForSeq2Seq(
-                tokenizer=tokenizer, padding=True, max_length=max_seq_length
-            )
-        raise ValueError(
-            "Could not pick a data collator. Please refer to supported data formats"
+        return DataCollatorForSeq2Seq(
+            tokenizer=tokenizer, padding=True, max_length=max_seq_length
         )
+
+    raise ValueError(
+        "Could not pick a data collator. Please refer to supported data formats"
+    )
diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py
index fc2789c69..bb3ab8b81 100644
--- a/tuning/data/setup_dataprocessor.py
+++ b/tuning/data/setup_dataprocessor.py
@@ -396,10 +396,6 @@ def process_dataargs(
         is_padding_free=is_padding_free,
     )
 
-    dataset_kwargs = {}
-    if is_tokenized_dataset:
-        dataset_kwargs["skip_prepare_dataset"] = True
-
    if isinstance(train_dataset, IterableDataset):
         train_args.accelerator_config = {"split_batches": True}
         logger.info(
@@ -415,5 +411,5 @@ def process_dataargs(
         dataset_text_field,
         data_collator,
         max_seq_length,
-        dataset_kwargs,
+        None,
     )
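
A minimal sketch of how the reworked collator selection could be exercised, assuming get_data_collator accepts the keyword arguments named in the hunks above (packing, response_template, tokenizer, is_traindata_tokenized, max_seq_length, instruction_template, is_padding_free); the "gpt2" tokenizer is an illustrative stand-in for whatever model is actually being tuned:

from transformers import AutoTokenizer

from tuning.data.data_preprocessing_utils import get_data_collator

# Illustrative tokenizer choice; any HF tokenizer would do here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Pretokenized data with packing: the new branch returns a DataCollatorForSeq2Seq with padding disabled.
collator = get_data_collator(
    packing=True,
    response_template=None,
    tokenizer=tokenizer,
    is_traindata_tokenized=True,
    max_seq_length=256,
    instruction_template=None,
    is_padding_free=False,
)
print(type(collator).__name__)  # expected: DataCollatorForSeq2Seq

# Non-tokenized data with packing: no collator is returned; SFTTrainer packs the text itself.
collator = get_data_collator(
    packing=True,
    response_template=None,
    tokenizer=tokenizer,
    is_traindata_tokenized=False,
    max_seq_length=256,
    instruction_template=None,
    is_padding_free=False,
)
print(collator)  # expected: None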