Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/data_utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@
## unpair_preference_dataset

[[autodoc]] unpair_preference_dataset

## truncate_dataset

[[autodoc]] truncate_dataset

## pack_dataset

[[autodoc]] pack_dataset

## PackingStrategy

[[autodoc]] PackingStrategy
45 changes: 37 additions & 8 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers import AutoProcessor, AutoTokenizer, is_vision_available

from trl.data_utils import (
PackingStrategy,
apply_chat_template,
extract_prompt,
is_conversational,
Expand Down Expand Up @@ -1056,57 +1057,79 @@ def test_with_dataset(self):
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
assert dataset.to_dict() == expected_output
assert format == dataset.format

def test_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
formatting = dataset._formatting
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
num_examples = len(examples[next(iter(examples))])
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting


class TestPackingStrategy(TrlTestCase):
def test_aliases(self):
assert PackingStrategy("bfd-split") is PackingStrategy.BFD_SPLIT
assert PackingStrategy("bfd-truncate") is PackingStrategy.BFD

def test_missing_value_raises_value_error(self):
with pytest.raises(ValueError, match="not a valid PackingStrategy"):
PackingStrategy("missing")


class TestPackDatasetBfd(TrlTestCase):
def test_simple(self):
def test_with_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"seq_lengths": [[4], [3, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
assert dataset.to_dict() == expected_output
assert format == dataset.format

def test_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
formatting = dataset._formatting
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"seq_lengths": [[4], [3, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
num_examples = len(examples[next(iter(examples))])
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting

def test_with_overlong_0(self):
examples = {
Expand All @@ -1118,7 +1141,7 @@ def test_with_overlong_0(self):
"input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 5, 12]],
"seq_lengths": [[4], [4], [2, 1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_with_overlong_two_coluns(self):
Expand All @@ -1133,7 +1156,7 @@ def test_with_overlong_two_coluns(self):
"col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
"seq_lengths": [[4], [4], [3], [3], [2]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_with_non_power_of_2(self):
Expand All @@ -1146,10 +1169,10 @@ def test_with_non_power_of_2(self):
"input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
"seq_lengths": [[5], [4, 1], [3]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_default_no_requeue(self):
def test_default_no_split(self):
"""Test default 'bfd' strategy for SFT datasets (truncates overflow)."""
examples = {
"input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
Expand All @@ -1172,28 +1195,34 @@ def test_with_dataset(self):
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
assert dataset.to_dict() == expected_output
assert format == dataset.format

def test_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
formatting = dataset._formatting
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
num_examples = len(examples[next(iter(examples))])
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting

def test_with_extra_column(self):
examples = {
Expand Down
2 changes: 2 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_import_structure = {
"chat_template_utils": ["add_response_schema", "clone_chat_template", "get_training_chat_template"],
"data_utils": [
"PackingStrategy",
"apply_chat_template",
"extract_prompt",
"is_conversational",
Expand Down Expand Up @@ -72,6 +73,7 @@
if TYPE_CHECKING:
from .chat_template_utils import add_response_schema, clone_chat_template, get_training_chat_template
from .data_utils import (
PackingStrategy,
apply_chat_template,
extract_prompt,
is_conversational,
Expand Down
Loading