Skip to content

Commit e08d8fe

Browse files
committed
Account for possible empty sequences
1 parent 5caa917 commit e08d8fe

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

tests/test_data_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,19 @@ def test_default_no_split(self):
11871187
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
11881188
assert dataset.to_dict() == expected_output
11891189

1190+
def test_with_empty_sequences(self):
    """Check that empty sequences are left out when packing with "bfd_split"."""
    # Two of the five input rows are empty lists.
    raw = Dataset.from_dict({"input_ids": [[1, 2], [], [3, 4, 5], [], [6]]})

    packed = pack_dataset(raw, 4, strategy="bfd_split")

    # Only the three non-empty sequences show up in the packed rows.
    assert packed.to_dict() == {
        "input_ids": [[3, 4, 5, 6], [1, 2]],
        "seq_lengths": [[3, 1], [2]],
    }
1202+
11901203

11911204
class TestTruncateExamples(TrlTestCase):
11921205
def test_with_dataset(self):

trl/data_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,10 +717,17 @@ def _pack_bfd(
717717
_check_if_columns_can_be_packed(columns)
718718
assert len(columns) > 0
719719

720+
lengths = pc.list_value_length(columns[0])
721+
722+
# Filter out empty sequences
723+
non_empty_mask = pc.greater(lengths, 0)
724+
columns = [pc.filter(column, non_empty_mask) for column in columns]
725+
lengths = pc.filter(lengths, non_empty_mask)
726+
720727
if on_seq_length_overflow == "truncate":
721728
columns = [pc.list_slice(column, 0, seq_length) for column in columns]
722729
elif on_seq_length_overflow == "split":
723-
lengths = pc.list_value_length(columns[0]).to_numpy()
730+
lengths = lengths.to_numpy()
724731
# Split the sequences longer than `seq_length` into chunks (of length `seq_length` or less) while respecting sequence boundaries
725732
num_fragments = np.ceil(lengths / seq_length).astype(int)
726733
offsets = np.arange(np.sum(num_fragments) + 1, dtype=columns[0].offsets.type.to_pandas_dtype()) * seq_length

0 commit comments

Comments
 (0)