Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ def test_with_overlong_0(self):
"input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 5, 12]],
"seq_lengths": [[4], [4], [2, 1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd-split")
assert dataset.to_dict() == expected_output

def test_with_overlong_two_coluns(self):
Expand All @@ -1133,7 +1133,7 @@ def test_with_overlong_two_coluns(self):
"col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
"seq_lengths": [[4], [4], [3], [3], [2]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd-split")
assert dataset.to_dict() == expected_output

def test_with_non_power_of_2(self):
Expand All @@ -1146,10 +1146,10 @@ def test_with_non_power_of_2(self):
"input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
"seq_lengths": [[5], [4, 1], [3]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd-split")
assert dataset.to_dict() == expected_output

def test_default_no_requeue(self):
def test_default_no_split(self):
"""Test default 'bfd' strategy for SFT datasets (truncates overflow)."""
examples = {
"input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
Expand Down
213 changes: 91 additions & 122 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from collections import defaultdict, deque
from collections.abc import Callable, Sequence
from itertools import takewhile
from typing import Any, TypeVar
from typing import Any, Literal, TypeVar

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.types
from datasets import Dataset, DatasetDict
from datasets import Dataset, DatasetDict, IterableDatasetDict
from transformers import PreTrainedTokenizerBase, ProcessorMixin


Expand Down Expand Up @@ -612,6 +612,29 @@ def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})


def _get_dataset_format(dataset: DatasetType) -> dict[str, Any]:
    """
    Return the format kwargs of `dataset` so they can later be restored with `dataset.with_format(**format)`.

    Args:
        dataset: A `Dataset`, `IterableDataset`, or a dict-like `DatasetDict`/`IterableDatasetDict`.

    Returns:
        A dict of keyword arguments accepted by `with_format` (for iterable datasets, only `{"type": ...}`).
    """
    if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
        # Dict-like containers share formatting across splits; inspect the first split.
        dataset = dataset[next(iter(dataset))]
    if isinstance(dataset, Dataset):
        # `Dataset.format` is already a dict of kwargs suitable for `with_format`.
        return dataset.format
    # IterableDataset stores its formatting in the private `_formatting` attribute (there is no
    # public `formatting` property), which is None when the dataset is unformatted. The original
    # read `dataset.formatting` here while guarding on `dataset._formatting`, which would raise
    # AttributeError whenever a format was set — use `_formatting` consistently.
    format_type = dataset._formatting.format_type if dataset._formatting is not None else None
    return {"type": format_type}


def _check_if_columns_can_be_packed(columns: list[pa.Array]):
    """
    Validate that every column is a (large) list array and that all columns share identical offsets,
    i.e. that row `i` has the same length in every column. Raises on the first violation.

    Raises:
        TypeError: If any column is not a list or large_list array.
        ValueError: If a column's row lengths differ from those of the first column.
    """
    reference_offsets = None
    for array in columns:
        list_like = pyarrow.types.is_list(array.type) or pyarrow.types.is_large_list(array.type)
        if not list_like:
            raise TypeError("Packing requires all columns to be lists of lists.")

        if reference_offsets is None:
            # First valid column: its offsets become the reference for all the others.
            reference_offsets = array.offsets
        elif reference_offsets != array.offsets:
            raise ValueError("All columns must have values of the same length.")


class _SegmentTree:
"""
A segment tree data structure that, when initialized as `_SegmentTree(maxval)`, efficiently finds the next larger
Expand Down Expand Up @@ -657,75 +680,38 @@ def search(self, val):
return self.tree[i]


def _pack_bfd(examples: pa.Table, seq_length: int, requeue_truncated_sequences: bool = False) -> pa.Table:
def _pack_bfd(
examples: pa.Table, seq_length: int, on_seq_length_overflow: Literal["truncate", "split"] = "truncate"
) -> pa.Table:
"""Pack sequences in a pyarrow Table using Best Fit Decreasing strategy."""
# Identify the list column and prepare all columns
columns = []
list_column_idx = None
for idx, column in enumerate(examples.columns):
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
if not (pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type)):
raise TypeError("pack_dataset(bfd) requires all columns to be list-like.")
if list_column_idx is None:
list_column_idx = idx
columns.append(column)

assert list_column_idx is not None
list_column = columns[list_column_idx]
offsets = np.asarray(list_column.offsets)
values = list_column.values

# Split every list row into fragments of length <= seq_length (so long rows become multiple samples).
frag_lengths: list[int] = []
frag_info: list[tuple[int, int, int]] = [] # (row_idx, split_start, frag_len)
expanded_indices: list[int] = []
for row_idx, (row_start, row_end) in enumerate(zip(offsets[:-1], offsets[1:], strict=False)):
length = row_end - row_start
for split_start in range(0, length, seq_length):
frag_len = min(seq_length, length - split_start)
# When requeue_truncated_sequences is False, only keep the first fragment (truncate overflow)
if not requeue_truncated_sequences and split_start > 0:
continue
# Clamp the first fragment to seq_length when not re-queuing
if not requeue_truncated_sequences and frag_len > seq_length:
frag_len = seq_length
frag_lengths.append(frag_len)
frag_info.append((row_idx, split_start, frag_len))
expanded_indices.append(row_idx)

# Rebuild list columns with fragments
offsets_type = list_column.offsets.type
new_offsets = np.empty(len(frag_lengths) + 1, dtype=offsets_type.to_pandas_dtype())
new_offsets[0] = 0
new_offsets[1:] = np.cumsum(frag_lengths, dtype=offsets_type.to_pandas_dtype())
new_offsets_array = pa.array(new_offsets, type=offsets_type)

for idx, column in enumerate(columns):
if idx == list_column_idx:
slices = [
values.slice(offsets[row_idx] + split_start, frag_len) for row_idx, split_start, frag_len in frag_info
]
new_values = pa.concat_arrays(slices)
columns[idx] = type(column).from_arrays(new_offsets_array, new_values)
continue

column_offsets = np.asarray(column.offsets)
column_values = column.values
slices = []
for row_idx, split_start, frag_len in frag_info:
row_len = column_offsets[row_idx + 1] - column_offsets[row_idx]
if row_len < split_start + frag_len:
raise ValueError("List columns must have matching lengths when packing datasets.")
start = column_offsets[row_idx] + split_start
slices.append(column_values.slice(start, frag_len))
column_offsets_array = pa.array(new_offsets, type=column.offsets.type)
columns[idx] = type(column).from_arrays(column_offsets_array, pa.concat_arrays(slices))
columns = [column.chunks[0] for column in examples.combine_chunks().columns]
_check_if_columns_can_be_packed(columns)
assert len(columns) > 0

lengths = pc.list_value_length(columns[0]).to_numpy()

if on_seq_length_overflow == "truncate":
columns = [pc.list_slice(column, 0, seq_length) for column in columns]
elif on_seq_length_overflow == "split":
# Split the sequences longer than `seq_length` into chunks (of length `seq_length` or less) while respecting sequence boundaries
num_fragments = np.ceil(lengths / seq_length).astype(int)
offsets = np.arange(np.sum(num_fragments) + 1, dtype=columns[0].offsets.type.to_pandas_dtype()) * seq_length
# "Left-shift" the offsets to account for the last fragment of each original sequence possibly being shorter than `seq_length`
diff = np.zeros_like(offsets)
diff[np.cumsum(num_fragments)] = -lengths % seq_length
diff = np.cumsum(diff)
offsets -= diff
columns = [
type(column).from_arrays(offsets.astype(column.offsets.type.to_pandas_dtype()), column.values)
for column in columns
]
else:
raise ValueError(f"Invalid `on_seq_length_overflow`: {on_seq_length_overflow}. Use 'truncate' or 'split'.")

examples = pa.Table.from_arrays(columns, names=examples.column_names)
ids = np.arange(len(examples))
lengths = pc.list_value_length(examples[list_column_idx]).combine_chunks()
lengths = pc.list_value_length(columns[0])
examples = examples.append_column("seq_lengths", lengths) # Allows us to later construct `position_ids`
ids = np.arange(len(examples))
lengths = pc.make_struct(lengths, ids)
lengths = lengths.sort("descending", by=0)

Expand Down Expand Up @@ -771,35 +757,33 @@ def _pack_bfd(examples: pa.Table, seq_length: int, requeue_truncated_sequences:
columns = []
for column in examples.columns:
column = column.chunks[0]
if pa.types.is_list(column.type) or pa.types.is_large_list(column.type):
dtype = column.offsets.type.to_pandas_dtype()
column = type(column).from_arrays(offsets.astype(dtype), column.values)
assert pa.types.is_list(column.type) or pa.types.is_large_list(column.type)
dtype = column.offsets.type.to_pandas_dtype()
column = type(column).from_arrays(offsets.astype(dtype), column.values)
columns.append(column)
return pa.Table.from_arrays(columns + [lengths], names=examples.column_names + ["seq_lengths"])


def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
"""Pack sequences in a pyarrow Table using a wrapped strategy."""
columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
offsets, values = column.offsets, column.values
values = values[offsets[0].as_py() : offsets[-1].as_py()]
num_elements = len(values)
dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64
offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
offsets = np.concatenate((offsets, [num_elements]))
column = type(column).from_arrays(offsets, values)
columns.append(column)
columns = [column.chunks[0] for column in examples.combine_chunks().columns]
_check_if_columns_can_be_packed(columns)
offsets, values = columns[0].offsets, columns[0].values
values = values[offsets[0].as_py() : offsets[-1].as_py()]
num_elements = len(values)
offsets = np.arange(0, num_elements, seq_length, dtype=columns[0].offsets.type.to_pandas_dtype())
offsets = np.concatenate((offsets, [num_elements]))
columns = [
type(column).from_arrays(offsets.astype(column.offsets.type.to_pandas_dtype()), column.values)
for column in columns
]
return pa.Table.from_arrays(columns, names=examples.column_names)


def pack_dataset(
dataset: DatasetType,
seq_length: int,
strategy: str = "bfd",
strategy: Literal["bfd", "bfd-split", "wrapped"] = "bfd",
map_kwargs: dict[str, Any] | None = None,
) -> DatasetType:
r"""
Expand All @@ -811,13 +795,13 @@ def pack_dataset(
seq_length (`int`):
Target sequence length to pack to.
strategy (`str`, *optional*, defaults to `"bfd"`):
Packing strategy to use. Can be one of:
            Packing strategy to use. Can be one of:

- `"bfd"` (Best Fit Decreasing): Preserves sequence boundaries and truncates sequences that exceed
`seq_length`, discarding overflow tokens. Ideal for SFT and conversational datasets where maintaining
conversation structure is important.
- `"bfd-requeue"`: Similar to `"bfd"` but re-queues truncated overflow tokens for packing into other
sequences. Prevents token loss for pre-training or long documents, but may break conversation structure
- `"bfd-split"`: Similar to `"bfd"` but splits overflow sequences for packing into other
examples. Prevents token loss for pre-training or long documents, but may break conversation structure
in SFT datasets.
- `"wrapped"`: Faster but more aggressive. Ignores sequence boundaries and will cut sequences in the middle
to completely fill each packed sequence with data.
Expand Down Expand Up @@ -845,8 +829,8 @@ def pack_dataset(
'attention_mask': [[1, 1, 1, 0], [1, 1, 0, 1], [1, 0]],
'seq_lengths': [[4], [3, 1], [2]]}

>>> # "bfd-requeue" strategy: preserves all tokens
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd-requeue")
>>> # "bfd-split" strategy: preserves all tokens
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd-split")
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 4], [8, 9, 10, 5], [6, 7, 11]],
'attention_mask': [[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 1]],
Expand All @@ -855,27 +839,27 @@ def pack_dataset(
"""
if map_kwargs is None:
map_kwargs = {}
# Fast packing with pyarrow
format = _get_dataset_format(dataset)
dataset = dataset.with_format("arrow")
if strategy == "bfd":
if strategy in {"bfd", "bfd-truncate"}:
dataset = dataset.map(
_pack_bfd,
batched=True,
fn_kwargs={"seq_length": seq_length, "requeue_truncated_sequences": False},
fn_kwargs={"seq_length": seq_length, "on_seq_length_overflow": "truncate"},
**map_kwargs,
)
elif strategy == "bfd-requeue":
elif strategy in {"bfd-split", "bfd-requeue"}:
dataset = dataset.map(
_pack_bfd,
batched=True,
fn_kwargs={"seq_length": seq_length, "requeue_truncated_sequences": True},
fn_kwargs={"seq_length": seq_length, "on_seq_length_overflow": "split"},
**map_kwargs,
)
elif strategy == "wrapped":
dataset = dataset.map(_pack_wrapped, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
else:
raise ValueError(f"Invalid packing strategy: {strategy}. Use 'bfd', 'bfd-requeue', or 'wrapped'.")
dataset = dataset.with_format(None)
raise ValueError(f"Invalid packing strategy: {strategy}. Use 'bfd', 'bfd-split', or 'wrapped'.")
dataset = dataset.with_format(**format)
return dataset


Expand Down Expand Up @@ -911,34 +895,19 @@ def truncate_dataset(dataset: DatasetType, max_length: int, map_kwargs: dict[str
"""
if map_kwargs is None:
map_kwargs = {}
if isinstance(dataset, Dataset):
# Fast truncation with pyarrow
def truncate(examples):
truncated_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
column = pc.list_slice(column, 0, max_length)
truncated_columns.append(column)
return pa.Table.from_arrays(truncated_columns, names=examples.column_names)

dataset = dataset.with_format("arrow")
dataset = dataset.map(truncate, batched=True, **map_kwargs)
dataset = dataset.with_format(None)
else:

def truncate(examples):
truncated_examples = {}
for key, column in examples.items():
if column and isinstance(column[0], list):
column = [val[:max_length] for val in column]
truncated_examples[key] = column
return truncated_examples
def truncate(examples):
truncated_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
column = pc.list_slice(column, 0, max_length)
truncated_columns.append(column)
return pa.Table.from_arrays(truncated_columns, names=examples.column_names)

dataset = dataset.map(
truncate,
batched=True,
**map_kwargs,
)
format = _get_dataset_format(dataset)
dataset = dataset.with_format("arrow")
dataset = dataset.map(truncate, batched=True, **map_kwargs)
dataset = dataset.with_format(**format)
return dataset


Expand Down
8 changes: 4 additions & 4 deletions trl/trainer/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any
from typing import Any, Literal

from transformers import TrainingArguments

Expand Down Expand Up @@ -190,13 +190,13 @@ class SFTConfig(_BaseConfig):
"and reduce padding. Uses `max_length` to define sequence length."
},
)
packing_strategy: str = field(
packing_strategy: Literal["bfd", "bfd-split", "wrapped"] = field(
default="bfd",
metadata={
"help": "Strategy for packing sequences. Can be `'bfd'` (best-fit decreasing, truncates overflow), "
"`'bfd-requeue'` (best-fit decreasing, re-queues overflow tokens), or `'wrapped'` (aggressive, cuts "
"`'bfd-split'` (best-fit decreasing, splits overflow sequences), or `'wrapped'` (aggressive, cuts "
"mid-sequence).",
"choices": ["bfd", "bfd-requeue", "wrapped"],
"choices": ["bfd", "bfd-split", "wrapped"],
},
)
padding_free: bool = field(
Expand Down
Loading