
Commit 17e8060

📦 Support for packing tokenized datasets for SFT (huggingface#2011)
* feat: add support for packing tokenized datasets
* fix: address review comments
* feat: add tests for pretokenized dataset packing

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 163695e commit 17e8060
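
In short, a dataset whose rows already contain `input_ids` can now be passed to `SFTTrainer` with `packing=True` and is packed without being re-tokenized. A minimal usage sketch based on the tests added in this commit; the model id, sample text, and output directory are placeholders, not part of the commit:

from datasets import Dataset
from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "Qwen/Qwen2-0.5B"  # placeholder: any small causal LM id works
tokenizer = AutoTokenizer.from_pretrained(model_id)

# A dataset that is already tokenized: each row only carries "input_ids".
tokenized_dataset = Dataset.from_dict(
    {"input_ids": [tokenizer.encode("TRL supports SFT, PPO, and DPO.")] * 10}
)

# With this change, packing such a dataset works out of the box.
training_args = SFTConfig(output_dir="sft-packed", packing=True, max_steps=4, report_to="none")
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset)
trainer.train()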

3 files changed: +99 −4 lines changed

tests/test_sft_trainer.py

Lines changed: 79 additions & 0 deletions
@@ -39,6 +39,10 @@ def formatting_prompts_func(example):
     return text
 
 
+def formatting_func_for_pretokenized(example):
+    return example["input_ids"]
+
+
 def formatting_prompts_func_batched(example):
     output_text = []
     for i, question in enumerate(example["question"]):
@@ -93,6 +97,17 @@ def setUp(self):
                 ],
             }
         )
+        self.dummy_tokenized_dataset = Dataset.from_dict(
+            {
+                "input_ids": [
+                    self.tokenizer.encode(
+                        "TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
+                    )
+                ]
+                * 10
+            }
+        )
+
         self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
         self.standard_prompt_completion_dataset = load_dataset(
             "trl-internal-testing/zen", "standard_prompt_completion"
@@ -158,6 +173,42 @@ def setUp(self):
             num_of_sequences=16,
         )
 
+        self.train_dataset_from_pretokenized = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            seq_length=16,
+            num_of_sequences=16,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+        self.eval_dataset_from_pretokenized = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            seq_length=16,
+            num_of_sequences=16,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+    def test_constant_length_dataset_with_pretokenized_data(self):
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+        assert len(constant_len_dataset) == len(self.dummy_tokenized_dataset)
+        assert len(constant_len_dataset) > 0
+
+        for example in constant_len_dataset:
+            assert "input_ids" in example
+            assert "labels" in example
+
+            assert len(example["input_ids"]) == constant_len_dataset.seq_length
+            assert len(example["labels"]) == constant_len_dataset.seq_length
+
+            decoded_text = self.tokenizer.decode(example["input_ids"])
+            assert ("TRL" in decoded_text) and ("(DPO)" in decoded_text)
+
     def test_constant_length_dataset(self):
         formatted_dataset = ConstantLengthDataset(
             self.tokenizer,
@@ -236,6 +287,34 @@ def test_sft_trainer(self):
 
             self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
 
+    def test_sft_trainer_with_pretokenzied_data_packing(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = SFTConfig(
+                output_dir=tmp_dir,
+                dataloader_drop_last=True,
+                eval_strategy="steps",
+                max_steps=4,
+                eval_steps=2,
+                save_steps=2,
+                per_device_train_batch_size=2,
+                packing=True,
+                report_to="none",
+            )
+
+            trainer = SFTTrainer(
+                model=self.model_id,
+                args=training_args,
+                train_dataset=self.train_dataset_from_pretokenized,
+                eval_dataset=self.eval_dataset_from_pretokenized,
+            )
+
+            trainer.train()
+
+            assert trainer.state.log_history[(-1)]["train_loss"] is not None
+            assert trainer.state.log_history[0]["eval_loss"] is not None
+
+            assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+
     def test_sft_trainer_uncorrect_data(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Shoud work as SFTTrainer natively supports conversational lm dataset

trl/trainer/sft_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -367,7 +367,11 @@ def _prepare_dataset(
                     "You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
                 )
 
-            return dataset
+            def formatting_func(x):
+                return x["input_ids"]
+
+            if not packing:
+                return dataset
 
         # check if torch dataset / dataloader and do nothing
         # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
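
In other words, `_prepare_dataset` no longer returns a pretokenized dataset early when packing is requested: it installs a default formatting function that hands back the existing token ids and only short-circuits when packing is off. A rough, simplified sketch of that decision (not the trainer's exact code; the helper name is made up for illustration):

import warnings

def _prepare_pretokenized(dataset, packing, user_formatting_func=None):
    # Sketch only: mirrors the new branch for datasets that already have "input_ids".
    if user_formatting_func is not None:
        warnings.warn("Dataset already contains `input_ids`; `formatting_func` will be ignored.")

    def formatting_func(example):
        return example["input_ids"]  # reuse existing token ids instead of re-tokenizing

    if not packing:
        return dataset, None  # unchanged behavior: hand the tokenized dataset to the trainer as-is

    return dataset, formatting_func  # fall through to the packing path (ConstantLengthDataset)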

trl/trainer/utils.py

Lines changed: 15 additions & 3 deletions
@@ -21,6 +21,7 @@
 from importlib.metadata import version
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
+import datasets
 import numpy as np
 import pandas as pd
 import torch
@@ -627,6 +628,14 @@ def __init__(
                 "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
                 " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
             )
+        self.pretokenized = False
+        column_names = (
+            dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
+        )
+        if column_names is not None and "input_ids" in column_names:
+            self.pretokenized = True
+            # since the dataset is tokenized, the unit of buffer size should be tokens
+            self.max_buffer_size = seq_length * num_of_sequences
 
     def __len__(self):
         return len(self.dataset)
@@ -651,9 +660,12 @@ def __iter__(self):
                         break
             if self.shuffle:
                 random.shuffle(buffer)
-            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
-                "input_ids"
-            ]
+            if self.pretokenized:
+                tokenized_inputs = buffer
+            else:
+                tokenized_inputs = self.tokenizer(
+                    buffer, add_special_tokens=self.add_special_tokens, truncation=False
+                )["input_ids"]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
                 if self.append_concat_token:
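
With this change, `ConstantLengthDataset` detects an `input_ids` column, switches its buffer accounting to tokens, and skips the tokenizer call in `__iter__`. A small usage sketch, assuming a placeholder tokenizer ("gpt2") and toy data:

from datasets import Dataset
from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer for illustration

# Rows already tokenized: only an "input_ids" column.
pretokenized = Dataset.from_dict(
    {"input_ids": [tokenizer.encode("TRL can now pack pretokenized rows.")] * 8}
)

packed = ConstantLengthDataset(
    tokenizer,
    pretokenized,
    formatting_func=lambda example: example["input_ids"],  # return the existing token ids
    seq_length=16,
    num_of_sequences=16,
)

for example in packed:
    # Each yielded example is a fixed-length packed chunk of token ids.
    assert len(example["input_ids"]) == 16
    break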
