trl/extras/dataset_formatting.py (5 changes: 3 additions, 2 deletions)

```diff
@@ -58,7 +58,7 @@ def format_dataset(examples):


 def get_formatting_func_from_dataset(
-    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
+    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, apply_chat_instruction_template=True
 ) -> Optional[Callable]:
     r"""
     Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
@@ -83,6 +83,7 @@ def get_formatting_func_from_dataset(
                 return conversations_formatting_function(tokenizer, "conversations")
         elif dataset.features == FORMAT_MAPPING["instruction"]:
             logging.info("Formatting dataset with instruction format")
-            return instructions_formatting_function(tokenizer)
+            if apply_chat_instruction_template:
+                return instructions_formatting_function(tokenizer)

     return None
```
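
Read against the hunk above, a minimal sketch of what the new flag changes at the formatting layer. The checkpoint name and example rows are placeholders, and the expected return values follow my reading of the diff rather than documented behaviour: by default an instruction-format dataset is still wrapped with the tokenizer's chat template, while `apply_chat_instruction_template=False` makes the function fall through and return `None`.

```python
# Sketch only: placeholder checkpoint and rows; expected return values per the diff above.
from datasets import Dataset
from transformers import AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# Matches FORMAT_MAPPING["instruction"]: exactly one string "prompt" and one string "completion" column.
dataset = Dataset.from_dict(
    {
        "prompt": ["What is 2 + 2?"],
        "completion": ["4"],
    }
)

# Default behaviour: the instruction-format dataset is routed through the chat template.
chat_fn = get_formatting_func_from_dataset(dataset, tokenizer)
assert callable(chat_fn)

# With the flag disabled, no formatting function is returned, leaving the raw
# prompt/completion columns for the prompt/completion tokenization path.
raw_fn = get_formatting_func_from_dataset(dataset, tokenizer, apply_chat_instruction_template=False)
assert raw_fn is None
```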
trl/trainer/sft_trainer.py (22 changes: 11 additions, 11 deletions)

```diff
@@ -145,6 +145,7 @@ def __init__(
         dataset_num_proc: Optional[int] = None,
         dataset_batch_size: int = 1000,
         neftune_noise_alpha: Optional[float] = None,
+        apply_chat_instruction_template = True,
         model_init_kwargs: Optional[Dict] = None,
         dataset_kwargs: Optional[Dict] = None,
     ):
@@ -250,7 +251,7 @@ def make_inputs_require_grad(module, input, output):
         if formatting_func is None and dataset_text_field is None:
             # check if dataset has ChatML format or instruction format and is supported
             # if not stays #None
-            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
+            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer, apply_chat_instruction_template)

         requires_input_output_keys = False
         if not packing:
```
Review thread on the context line `requires_input_output_keys = False`:

Collaborator: This variable name should probably be changed; it doesn't make sense if the keys are prompt / completion.

Author: Ya, I changed it below and left that one place, thanks for catching it.
```diff
@@ -373,7 +374,7 @@ def _prepare_dataset(
         num_of_sequences,
         chars_per_token,
         remove_unused_columns=True,
-        requires_input_output_keys=False,
+        requires_prompt_completion_keys=False,
         append_concat_token=True,
         add_special_tokens=True,
     ):
@@ -393,7 +394,7 @@
                 formatting_func,
                 add_special_tokens,
                 remove_unused_columns,
-                requires_input_output_keys,
+                requires_prompt_completion_keys,
             )

         else:
@@ -418,14 +419,13 @@ def _prepare_non_packed_dataloader(
         formatting_func=None,
         add_special_tokens=True,
         remove_unused_columns=True,
-        requires_input_output_keys=False,
+        requires_prompt_completion_keys=False,
     ):
         use_formatting_func = formatting_func is not None and dataset_text_field is None
         self._dataset_sanity_checked = False

-        # TODO : fix how EOS tokens are handled
         # Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
-        def tokenize_input_output(element):
+        def tokenize_prompt_completion(element):

            # It is difficult to add special tokens here, as separator / EOS tokens that may be added while tokenizing
            # input texts may differ from concatenated text, making masking on input length incorrect.
@@ -437,7 +437,7 @@ def tokenize_input_output(element):
            )

            new_source = []
-            for (input_element, output_element) in zip(element['input'], element['output']):
+            for (input_element, output_element) in zip(element['prompt'], element['completion']):
                if not input_element.endswith((' ', '\n', '\t')) and not output_element.startswith((' ', '\n', '\t')):
                    new_source.append(input_element + ' ' + output_element)
                else:
```
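
The loop above joins each prompt/completion pair, inserting a single separating space only when neither side already supplies whitespace at the boundary. A standalone sketch of that rule (the helper name and sample rows are illustrative; the `else` branch is assumed to concatenate without an extra space):

```python
# Illustration of the whitespace-aware join used in tokenize_prompt_completion; not part of the diff.
def join_prompt_completion(prompt: str, completion: str) -> str:
    # Add a single space only if the boundary has no whitespace on either side.
    if not prompt.endswith((" ", "\n", "\t")) and not completion.startswith((" ", "\n", "\t")):
        return prompt + " " + completion
    return prompt + completion


element = {
    "prompt": ["Translate to French:", "Summarize:\n"],
    "completion": ["Hello", "A long article..."],
}
new_source = [join_prompt_completion(p, c) for p, c in zip(element["prompt"], element["completion"])]
# -> ["Translate to French: Hello", "Summarize:\nA long article..."]
```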
```diff
@@ -491,15 +491,15 @@ def tokenize(element):
                f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
            )

-        if requires_input_output_keys:
-            if "input" in dataset.column_names and "output" in dataset.column_names:
+        if requires_prompt_completion_keys:
+            if "prompt" in dataset.column_names and "completion" in dataset.column_names:
                # TODO: if we execute this input path, it is expected that we are using a seq2seq
                # collator. If that is the case, the tokenizer should had a pad_token; this is set
                # to eos automatically if it's unset and no tokenizer is provided, but we should
                # properly handle if a tokenizer with no padding token is given.
-                tokenize_func = tokenize_input_output
+                tokenize_func = tokenize_prompt_completion
            else:
-                raise KeyError("Missing input / output keys")
+                raise KeyError("Missing prompt / completion keys")
        else:
            tokenize_func = tokenize

```
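
Putting the two files together, a rough end-to-end sketch, not verified against the rest of this branch: a prompt/completion dataset handed to `SFTTrainer` with the chat-instruction template disabled. The checkpoint, the sample rows, and every keyword argument other than `apply_chat_instruction_template` are assumptions; the TODO above also notes this path expects a seq2seq-style collator and a tokenizer with a pad token, which is not shown here.

```python
# Rough sketch only: assumes this branch is installed; names outside the diff are assumptions.
from datasets import Dataset
from trl import SFTTrainer

train_dataset = Dataset.from_dict(
    {
        "prompt": ["Question: What is the capital of France?\nAnswer:"],
        "completion": [" Paris"],
    }
)

trainer = SFTTrainer(
    model="gpt2",                            # placeholder checkpoint
    train_dataset=train_dataset,
    packing=False,                           # the prompt/completion path only applies to non-packed datasets
    max_seq_length=512,
    apply_chat_instruction_template=False,   # new flag from this PR: skip the chat-instruction template
)
# trainer.train()  # would fine-tune on the concatenated prompt/completion pairs
```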