
Commit b7652a9

Normal padding (no packing) in SFT
1 parent ad346cd commit b7652a9

File tree (6 files changed: +195, -21 lines)

apps/sft/llama3_8b.yaml
apps/sft/main.py
apps/sft/qwen3_8b.yaml
src/forge/data/__init__.py
src/forge/data/collate.py
tests/unit_tests/datasets/test_packed.py


apps/sft/llama3_8b.yaml
Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ lr_scheduler:
   warmup_steps: 200
 
 training:
-  local_batch_size: 1
+  local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000

apps/sft/main.py
Lines changed: 2 additions & 17 deletions

@@ -16,15 +16,13 @@
 import math
 import os
 import sys
-from functools import partial
 from typing import Any
 
 import torch
 
 import torchtitan.experiments.forge.train_spec as forge_train_spec
 from forge.controller import ForgeActor
-from forge.data.collate import collate_packed
-from forge.data.datasets.packed import PackedDataset, TextPacker
+from forge.data.collate import collate_padded
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.data.utils import StopAfterOneEpoch
@@ -197,25 +195,12 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
             **dataset_config,
         )
 
-        packer = TextPacker(padding_idx=0)
-        dataset = PackedDataset(
-            dataset=dataset,
-            packer=packer,
-            target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
-        )
-
         dataloader = StatefulDataLoader(
             dataset=dataset,
             batch_size=self.job_config.training.local_batch_size,
-            collate_fn=partial(
-                collate_packed, mask_fn=packer.create_block_mask, device=self.device
-            ),
+            collate_fn=collate_padded,
         )
 
-        # Ultimately we probably want something like this
-        # packer = build_packing_strategy(packing_config)
-        # dataset = build_dataset(dataset_config)
-        # dataloader = build_dataloader(dataloader_config, dataset, packer)
         return dataloader
 
     def forward_backward(
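For context, here is a minimal self-contained sketch of the new data path (not code from this commit): a toy in-memory list stands in for the SFT iterable dataset, and batches are collated with collate_padded the same way setup_data now does. The torchdata import path, the toy samples, and the batch size of 2 are assumptions for illustration only.

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from forge.data.collate import collate_padded

# Toy stand-in for the SFT dataset: dicts with variable-length "tokens"
# and "labels" tensors, which is all collate_padded expects.
samples = [
    {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([1, 2, 3])},
    {"tokens": torch.tensor([4, 5]), "labels": torch.tensor([4, 5])},
]

dataloader = StatefulDataLoader(
    dataset=samples,
    batch_size=2,  # the YAML configs in this commit set local_batch_size: 8
    collate_fn=collate_padded,  # per-batch padding, no packing or block masks
)

for batch in dataloader:
    print(batch["tokens"].shape)  # torch.Size([2, 3]); the shorter sample is padded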

apps/sft/qwen3_8b.yaml
Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ lr_scheduler:
   warmup_steps: 200
 
 training:
-  local_batch_size: 1
+  local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000

src/forge/data/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -3,12 +3,13 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from .collate import collate_packed
+from .collate import collate_packed, collate_padded
 from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
 from .utils import CROSS_ENTROPY_IGNORE_IDX
 
 __all__ = [
     "collate_packed",
+    "collate_padded",
     "CROSS_ENTROPY_IGNORE_IDX",
     "MetricTransform",
     "DefaultDatasetMetricTransform",

src/forge/data/collate.py
Lines changed: 57 additions & 0 deletions

@@ -7,6 +7,63 @@
 from typing import Any, Callable
 
 import torch
+import torch.nn.functional as F
+
+from forge.data.utils import CROSS_ENTROPY_IGNORE_IDX
+
+
+def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
+    """
+    Collate function that pads sequences to the longest sample in the batch.
+
+    Pads 'tokens' with 0 and 'labels' with CROSS_ENTROPY_IGNORE_IDX (-100).
+    Non-tensor fields (like metrics) are collected into lists and flattened
+    if all items are lists.
+
+    Args:
+        batch: List of samples, each containing 'tokens' and 'labels' tensors
+
+    Returns:
+        Batched dict with padded tensors
+    """
+    if not batch:
+        return {}
+
+    # Find max length in batch
+    max_len = max(sample["tokens"].size(0) for sample in batch)
+
+    # Initialize lists for batched tensors
+    tokens_list = []
+    labels_list = []
+
+    # Pad each sample to max_len
+    for sample in batch:
+        seq_len = sample["tokens"].size(0)
+        pad_len = max_len - seq_len
+
+        # Pad tokens with 0
+        padded_tokens = F.pad(sample["tokens"], (0, pad_len), value=0)
+        tokens_list.append(padded_tokens)
+
+        # Pad labels with CROSS_ENTROPY_IGNORE_IDX (-100)
+        padded_labels = F.pad(sample["labels"], (0, pad_len), value=CROSS_ENTROPY_IGNORE_IDX)
+        labels_list.append(padded_labels)
+
+    # Stack into batch
+    result = {
+        "tokens": torch.stack(tokens_list),
+        "labels": torch.stack(labels_list),
+    }
+
+    # Collect non-tensor fields (like metrics)
+    for key in batch[0].keys():
+        if key not in ["tokens", "labels"]:
+            result[key] = [sample[key] for sample in batch]
+            # Flatten if all are lists
+            if all(isinstance(item, list) for item in result[key]):
+                result[key] = [item for sublist in result[key] for item in sublist]
+
+    return result
 
 
 def collate_packed(
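As a quick illustration of the contract described in the docstring (a sketch, not part of the commit; the new unit tests below exercise the same behavior more thoroughly): tokens are right-padded with 0 and labels with CROSS_ENTROPY_IGNORE_IDX (-100) up to the longest sample in the batch.

import torch

from forge.data.collate import collate_padded

batch = [
    {"tokens": torch.tensor([1, 2, 3, 4]), "labels": torch.tensor([1, 2, 3, 4])},
    {"tokens": torch.tensor([5, 6]), "labels": torch.tensor([5, 6])},
]

out = collate_padded(batch)
# Both samples are padded to length 4 (the longest in the batch).
assert out["tokens"].tolist() == [[1, 2, 3, 4], [5, 6, 0, 0]]
# Padded label positions are -100, so the cross-entropy loss ignores them.
assert out["labels"].tolist() == [[1, 2, 3, 4], [5, 6, -100, -100]]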

tests/unit_tests/datasets/test_packed.py
Lines changed: 132 additions & 1 deletion

@@ -13,7 +13,8 @@
 import pytest
 import torch
 
-from forge.data.collate import collate_packed
+from forge.data import CROSS_ENTROPY_IGNORE_IDX
+from forge.data.collate import collate_packed, collate_padded
 from forge.data.datasets import HfIterableDataset
 from forge.data.datasets.packed import (
     _SUPPORTS_FLEX_ATTENTION,
@@ -995,3 +996,133 @@ def test_iter_restart_determinism(self, dataset_factory):
             pack2["document_ids"],
             msg=f"Pack {i}: document_ids mismatch between iterations",
         )
+
+
+class TestCollatePadded:
+    """Test collate_padded function"""
+
+    def test_empty_batch(self):
+        """Test collating an empty batch"""
+        result = collate_padded([])
+        assert result == {}
+
+    def test_single_sample(self):
+        """Test collating a single sample"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2, 3]),
+                "labels": torch.tensor([4, 5, 6]),
+            }
+        ]
+        result = collate_padded(batch)
+
+        assert result["tokens"].shape == (1, 3)
+        assert result["labels"].shape == (1, 3)
+        torch.testing.assert_close(result["tokens"], torch.tensor([[1, 2, 3]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[4, 5, 6]]))
+
+    def test_equal_length_samples(self):
+        """Test collating samples with equal lengths"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2, 3]),
+                "labels": torch.tensor([4, 5, 6]),
+            },
+            {
+                "tokens": torch.tensor([7, 8, 9]),
+                "labels": torch.tensor([10, 11, 12]),
+            },
+        ]
+        result = collate_padded(batch)
+
+        assert result["tokens"].shape == (2, 3)
+        assert result["labels"].shape == (2, 3)
+        torch.testing.assert_close(
+            result["tokens"], torch.tensor([[1, 2, 3], [7, 8, 9]])
+        )
+        torch.testing.assert_close(
+            result["labels"], torch.tensor([[4, 5, 6], [10, 11, 12]])
+        )
+
+    def test_padding_to_longest(self):
+        """Test padding shorter sequences to the longest in batch"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2]),
+                "labels": torch.tensor([3, 4]),
+            },
+            {
+                "tokens": torch.tensor([5, 6, 7, 8]),
+                "labels": torch.tensor([9, 10, 11, 12]),
+            },
+            {
+                "tokens": torch.tensor([13, 14, 15]),
+                "labels": torch.tensor([16, 17, 18]),
+            },
+        ]
+        result = collate_padded(batch)
+
+        # All should be padded to length 4 (longest)
+        assert result["tokens"].shape == (3, 4)
+        assert result["labels"].shape == (3, 4)
+
+        # Check tokens padding (padded with 0)
+        torch.testing.assert_close(
+            result["tokens"],
+            torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8], [13, 14, 15, 0]]),
+        )
+
+        # Check labels padding (padded with CROSS_ENTROPY_IGNORE_IDX)
+        torch.testing.assert_close(
+            result["labels"],
+            torch.tensor(
+                [
+                    [3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
+                    [9, 10, 11, 12],
+                    [16, 17, 18, CROSS_ENTROPY_IGNORE_IDX],
+                ]
+            ),
+        )
+
+    def test_non_tensor_fields_preserved(self):
+        """Test that non-tensor fields are collected correctly"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2]),
+                "labels": torch.tensor([3, 4]),
+                "metadata": "sample1",
+            },
+            {
+                "tokens": torch.tensor([5, 6, 7]),
+                "labels": torch.tensor([8, 9, 10]),
+                "metadata": "sample2",
+            },
+        ]
+        result = collate_padded(batch)
+
+        assert "metadata" in result
+        assert result["metadata"] == ["sample1", "sample2"]
+
+    def test_metrics_flattened(self):
+        """Test that metrics lists are flattened"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2]),
+                "labels": torch.tensor([3, 4]),
+                "metrics": [
+                    type("Metric", (), {"key": "loss", "value": 1.0})(),
+                    type("Metric", (), {"key": "acc", "value": 0.9})(),
+                ],
+            },
+            {
+                "tokens": torch.tensor([5, 6, 7]),
+                "labels": torch.tensor([8, 9, 10]),
+                "metrics": [type("Metric", (), {"key": "loss", "value": 2.0})()],
+            },
+        ]
+        result = collate_padded(batch)
+
+        assert "metrics" in result
+        # Should be flattened from [[metric1, metric2], [metric3]] to [metric1, metric2, metric3]
+        assert len(result["metrics"]) == 3