Skip to content

Commit 15a8367

Browse files
committed
Fix BFD formatting
1 parent 86e89a2 commit 15a8367

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

tests/test_data_utils.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1111,7 +1111,10 @@ def test_with_dataset(self):
11111111
"seq_lengths": [[4], [3, 1]],
11121112
}
11131113
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
1114+
expected_format = dataset.format
11141115
assert dataset.to_dict() == expected_output
1116+
assert "seq_lengths" in expected_format["columns"]
1117+
expected_format["columns"].remove("seq_lengths")
11151118
assert format == dataset.format
11161119

11171120
def test_with_iterable_dataset(self):

trl/data_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -894,6 +894,8 @@ def pack_dataset(
894894
)
895895
else: # PackingStrategy.WRAPPED
896896
dataset = dataset.map(_pack_wrapped, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
897+
if strategy in {PackingStrategy.BFD, PackingStrategy.BFD_SPLIT} and "columns" in format:
898+
format["columns"].append("seq_lengths")
897899
dataset = dataset.with_format(**format)
898900
return dataset
899901

0 commit comments

Comments
 (0)