Skip to content

Commit 1772b9e

Browse files
committed
Make padding more generic
1 parent 240abf0 commit 1772b9e

File tree

2 files changed

+97
-40
lines changed

2 files changed

+97
-40
lines changed

src/forge/data/collate.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Collate a batch of samples, padding tensor fields to a common length.

    Handles any tensor keys by padding along dim 0 to the longest
    sequence for that key. Uses 0 as the default padding value, and
    CROSS_ENTROPY_IGNORE_IDX (-100) for 'labels' keys.

    Non-tensor fields are collected into lists. The 'metrics' field is
    special-cased to be flattened (extended) rather than nested.

    Args:
        batch: List of samples, each containing tensor and non-tensor fields

    Returns:
        Batched dict with padded tensors and collected non-tensor fields

    Raises:
        ValueError: If all samples do not have the same keys
    """
    if not batch:
        return {}

    # Verify all samples have the same keys, so every key below is
    # guaranteed to exist in every sample.
    first_sample_keys = batch[0].keys()
    for sample in batch:
        if sample.keys() != first_sample_keys:
            raise ValueError(
                f"All samples must have the same keys. Expected {first_sample_keys}, got {sample.keys()}"
            )

    collated: dict[str, Any] = {}

    for key in first_sample_keys:
        if isinstance(batch[0][key], torch.Tensor):
            # Find max length (along dim 0) for this tensor key.
            max_len = max(sample[key].size(0) for sample in batch)

            # 'labels' is padded with the loss ignore index so padded
            # positions don't contribute to the loss; everything else
            # pads with 0.
            pad_value = CROSS_ENTROPY_IGNORE_IDX if key == "labels" else 0

            # Pad each sample to max_len along dim 0.
            padded_tensors = []
            for sample in batch:
                pad_len = max_len - sample[key].size(0)
                # F.pad's pad tuple addresses the *last* dims first, so to
                # pad dim 0 of an n-d tensor we prefix (0, 0) pairs for all
                # trailing dims. For 1-D tensors this reduces to
                # (0, pad_len), identical to padding the only dim directly.
                pad_spec = (0, 0) * (sample[key].dim() - 1) + (0, pad_len)
                padded_tensors.append(F.pad(sample[key], pad_spec, value=pad_value))

            # Stack into a batch dimension.
            collated[key] = torch.stack(padded_tensors)
        elif key == "metrics":
            # Flatten per-sample metric lists into one list for the batch.
            collated[key] = []
            for sample in batch:
                collated[key].extend(sample[key])
        else:
            # Collect other non-tensor fields as lists, one entry per sample.
            collated[key] = [sample[key] for sample in batch]

    return collated
6976

7077

7178
def collate_packed(

tests/unit_tests/datasets/test_packed.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,3 +1125,53 @@ def test_metrics_flattened(self):
11251125
assert "metrics" in result
11261126
# Should be flattened from [[metric1, metric2], [metric3]] to [metric1, metric2, metric3]
11271127
assert len(result["metrics"]) == 3
1128+
1129+
def test_different_keys_error(self):
    """Mismatched key sets across batch samples must raise ValueError."""
    sample_a = {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])}
    sample_b = {"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])}

    with pytest.raises(ValueError, match="All samples must have the same keys"):
        collate_padded([sample_a, sample_b])
1138+
1139+
def test_generic_tensor_handling(self):
    """Any tensor field, including unknown keys, should be padded correctly."""
    short_sample = {
        "tokens": torch.tensor([1, 2]),
        "labels": torch.tensor([3, 4]),
        "custom_tensor": torch.tensor([100, 200, 300]),
    }
    long_sample = {
        "tokens": torch.tensor([5, 6, 7, 8]),
        "labels": torch.tensor([9, 10, 11, 12]),
        "custom_tensor": torch.tensor([400]),
    }
    result = collate_padded([short_sample, long_sample])

    # 'tokens' padded with 0 up to the longest sequence (length 4).
    assert result["tokens"].shape == (2, 4)
    expected_tokens = torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8]])
    torch.testing.assert_close(result["tokens"], expected_tokens)

    # 'labels' padded with the loss ignore index rather than 0.
    assert result["labels"].shape == (2, 4)
    ignore = CROSS_ENTROPY_IGNORE_IDX
    expected_labels = torch.tensor([[3, 4, ignore, ignore], [9, 10, 11, 12]])
    torch.testing.assert_close(result["labels"], expected_labels)

    # Unknown tensor keys fall back to 0-padding (longest is length 3).
    assert result["custom_tensor"].shape == (2, 3)
    expected_custom = torch.tensor([[100, 200, 300], [400, 0, 0]])
    torch.testing.assert_close(result["custom_tensor"], expected_custom)

0 commit comments

Comments
 (0)