
Commit 700b2f5

Temporarily use no packing in SFT (#614)
1 parent 55ec276 commit 700b2f5

6 files changed: +261 -21 lines changed

apps/sft/llama3_8b.yaml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ lr_scheduler:
   warmup_steps: 200
 
 training:
-  local_batch_size: 1
+  local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000

apps/sft/main.py

Lines changed: 10 additions & 17 deletions
@@ -16,15 +16,13 @@
 import math
 import os
 import sys
-from functools import partial
 from typing import Any
 
 import torch
 
 import torchtitan.experiments.forge.train_spec as forge_train_spec
 from forge.controller import ForgeActor
-from forge.data.collate import collate_packed
-from forge.data.datasets.packed import PackedDataset, TextPacker
+from forge.data.collate import collate_padded
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.data.utils import StopAfterOneEpoch
@@ -97,6 +95,13 @@ def record_batch_metrics(self, data_metrics: list):
 
     @endpoint
     async def setup(self):
+        # Validate that compile is only used with flex attention
+        if self.job_config.training.compile:
+            raise ValueError(
+                "training.compile=True is not currently supported. "
+                "Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
+                "Please set training.compile=false in your config."
+            )
 
         # all ranks should record loss, except when PP=True. Then, only the last stage should record loss.
         self.rank_should_record_loss = True
@@ -152,6 +157,7 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
         Raises:
             ValueError: If multiple datasets provided (not yet supported)
         """
+
         # TODO felipemello: Currently only support single dataset
         if len(dataset_configs) > 1:
             raise ValueError(
@@ -197,25 +203,12 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
             **dataset_config,
         )
 
-        packer = TextPacker(padding_idx=0)
-        dataset = PackedDataset(
-            dataset=dataset,
-            packer=packer,
-            target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
-        )
-
         dataloader = StatefulDataLoader(
             dataset=dataset,
             batch_size=self.job_config.training.local_batch_size,
-            collate_fn=partial(
-                collate_packed, mask_fn=packer.create_block_mask, device=self.device
-            ),
+            collate_fn=collate_padded,
         )
 
-        # Ultimately we probably want something like this
-        # packer = build_packing_strategy(packing_config)
-        # dataset = build_dataset(dataset_config)
-        # dataloader = build_dataloader(dataloader_config, dataset, packer)
         return dataloader
 
     def forward_backward(
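
For orientation (not part of this commit's diff): a minimal sketch of how the padded collate path is exercised downstream. It assumes forge is importable and substitutes torch's built-in DataLoader for StatefulDataLoader to stay self-contained; the toy samples stand in for the dicts the SFT dataset yields.

    import torch
    from torch.utils.data import DataLoader

    from forge.data.collate import collate_padded

    # Toy variable-length samples, mirroring what the SFT dataset produces
    samples = [
        {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5, 6])},
        {"tokens": torch.tensor([7, 8]), "labels": torch.tensor([9, 10])},
    ]

    loader = DataLoader(samples, batch_size=2, collate_fn=collate_padded)
    batch = next(iter(loader))
    print(batch["tokens"])  # shorter sample right-padded with 0 -> shape (2, 3)
    print(batch["labels"])  # labels padded with CROSS_ENTROPY_IGNORE_IDX (-100)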

apps/sft/qwen3_8b.yaml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ lr_scheduler:
   warmup_steps: 200
 
 training:
-  local_batch_size: 1
+  local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000

src/forge/data/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -3,12 +3,13 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from .collate import collate_packed
+from .collate import collate_packed, collate_padded
 from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
 from .utils import CROSS_ENTROPY_IGNORE_IDX
 
 __all__ = [
     "collate_packed",
+    "collate_padded",
     "CROSS_ENTROPY_IGNORE_IDX",
     "MetricTransform",
     "DefaultDatasetMetricTransform",

src/forge/data/collate.py

Lines changed: 66 additions & 0 deletions
@@ -7,6 +7,72 @@
 from typing import Any, Callable
 
 import torch
+import torch.nn.functional as F
+
+from forge.data.utils import CROSS_ENTROPY_IGNORE_IDX
+
+
+def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
+    """
+    Collate function that pads sequences to the longest sample in the batch.
+
+    Handles any tensor keys by padding to the longest
+    sequence for that key. Uses 0 as default padding value, and
+    CROSS_ENTROPY_IGNORE_IDX (-100) for 'labels' keys.
+
+    Non-tensor fields are collected into lists. The 'metrics' field is
+    special-cased to be flattened (extended) rather than nested.
+
+    Args:
+        batch: List of samples, each containing tensor and non-tensor fields
+
+    Returns:
+        Batched dict with padded tensors and collected non-tensor fields
+
+    Raises:
+        ValueError: If all samples do not have the same keys
+    """
+    if not batch:
+        return {}
+
+    # Verify all samples have the same keys
+    first_sample_keys = batch[0].keys()
+    for sample in batch:
+        if sample.keys() != first_sample_keys:
+            raise ValueError(
+                f"All samples must have the same keys. Expected {first_sample_keys}, got {sample.keys()}"
+            )
+
+    collated = {}
+
+    for key in first_sample_keys:
+        if isinstance(batch[0][key], torch.Tensor):
+            # Find max length for this tensor key
+            max_len = max(sample[key].size(0) for sample in batch)
+
+            # Determine padding value
+            pad_value = CROSS_ENTROPY_IGNORE_IDX if key == "labels" else 0
+
+            # Pad each sample to max_len
+            padded_tensors = []
+            for sample in batch:
+                seq_len = sample[key].size(0)
+                pad_len = max_len - seq_len
+                padded = F.pad(sample[key], (0, pad_len), value=pad_value)
+                padded_tensors.append(padded)
+
+            # Stack into batch
+            collated[key] = torch.stack(padded_tensors)
+        elif key == "metrics":
+            # Flatten metrics lists
+            collated[key] = []
+            for sample in batch:
+                collated[key].extend(sample[key])
+        else:
+            # Collect other non-tensor fields as lists
+            collated[key] = [sample[key] for sample in batch]
+
+    return collated
 
 
 def collate_packed(
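
A quick hedged check of the new function's behavior (again, not part of the commit; it assumes forge is importable). It exercises two of the code paths: tensor keys are right-padded (0 for tokens, -100 for labels) and non-tensor fields are collected into a list.

    import torch
    from forge.data.collate import collate_padded

    batch = collate_padded(
        [
            {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4]), "id": "a"},
            {"tokens": torch.tensor([5, 6, 7]), "labels": torch.tensor([8, 9, 10]), "id": "b"},
        ]
    )
    assert batch["tokens"].tolist() == [[1, 2, 0], [5, 6, 7]]
    assert batch["labels"].tolist() == [[3, 4, -100], [8, 9, 10]]
    assert batch["id"] == ["a", "b"]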

tests/unit_tests/datasets/test_packed.py

Lines changed: 181 additions & 1 deletion
@@ -13,7 +13,8 @@
 import pytest
 import torch
 
-from forge.data.collate import collate_packed
+from forge.data import CROSS_ENTROPY_IGNORE_IDX
+from forge.data.collate import collate_packed, collate_padded
 from forge.data.datasets import HfIterableDataset
 from forge.data.datasets.packed import (
     _SUPPORTS_FLEX_ATTENTION,
@@ -995,3 +996,182 @@ def test_iter_restart_determinism(self, dataset_factory):
             pack2["document_ids"],
             msg=f"Pack {i}: document_ids mismatch between iterations",
         )
+
+
+class TestCollatePadded:
+    """Test collate_padded function"""
+
+    def test_empty_batch(self):
+        """Test collating an empty batch"""
+        result = collate_padded([])
+        assert result == {}
+
+    def test_single_sample(self):
+        """Test collating a single sample"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2, 3]),
+                "labels": torch.tensor([4, 5, 6]),
+            }
+        ]
+        result = collate_padded(batch)
+
+        assert result["tokens"].shape == (1, 3)
+        assert result["labels"].shape == (1, 3)
+        torch.testing.assert_close(result["tokens"], torch.tensor([[1, 2, 3]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[4, 5, 6]]))
+
+    def test_equal_length_samples(self):
+        """Test collating samples with equal lengths"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2, 3]),
+                "labels": torch.tensor([4, 5, 6]),
+            },
+            {
+                "tokens": torch.tensor([7, 8, 9]),
+                "labels": torch.tensor([10, 11, 12]),
+            },
+        ]
+        result = collate_padded(batch)
+
+        assert result["tokens"].shape == (2, 3)
+        assert result["labels"].shape == (2, 3)
+        torch.testing.assert_close(
+            result["tokens"], torch.tensor([[1, 2, 3], [7, 8, 9]])
+        )
+        torch.testing.assert_close(
+            result["labels"], torch.tensor([[4, 5, 6], [10, 11, 12]])
+        )
+
+    def test_padding_to_longest(self):
+        """Test padding shorter sequences to the longest in batch"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2]),
+                "labels": torch.tensor([3, 4]),
+            },
+            {
+                "tokens": torch.tensor([5, 6, 7, 8]),
+                "labels": torch.tensor([9, 10, 11, 12]),
+            },
+            {
+                "tokens": torch.tensor([13, 14, 15]),
+                "labels": torch.tensor([16, 17, 18]),
+            },
+        ]
+        result = collate_padded(batch)
+
+        # All should be padded to length 4 (longest)
+        assert result["tokens"].shape == (3, 4)
+        assert result["labels"].shape == (3, 4)
+
+        # Check tokens padding (padded with 0)
+        torch.testing.assert_close(
+            result["tokens"],
+            torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8], [13, 14, 15, 0]]),
+        )
+
+        # Check labels padding (padded with CROSS_ENTROPY_IGNORE_IDX)
+        torch.testing.assert_close(
+            result["labels"],
+            torch.tensor(
+                [
+                    [3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
+                    [9, 10, 11, 12],
+                    [16, 17, 18, CROSS_ENTROPY_IGNORE_IDX],
+                ]
+            ),
+        )
+
+    def test_non_tensor_fields_preserved(self):
+        """Test that non-tensor fields are collected correctly"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2]),
+                "labels": torch.tensor([3, 4]),
+                "metadata": "sample1",
+            },
+            {
+                "tokens": torch.tensor([5, 6, 7]),
+                "labels": torch.tensor([8, 9, 10]),
+                "metadata": "sample2",
+            },
+        ]
+        result = collate_padded(batch)
+
+        assert "metadata" in result
+        assert result["metadata"] == ["sample1", "sample2"]
+
+    def test_metrics_flattened(self):
+        """Test that metrics lists are flattened"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2]),
+                "labels": torch.tensor([3, 4]),
+                "metrics": [
+                    type("Metric", (), {"key": "loss", "value": 1.0})(),
+                    type("Metric", (), {"key": "acc", "value": 0.9})(),
+                ],
+            },
+            {
+                "tokens": torch.tensor([5, 6, 7]),
+                "labels": torch.tensor([8, 9, 10]),
+                "metrics": [type("Metric", (), {"key": "loss", "value": 2.0})()],
+            },
+        ]
+        result = collate_padded(batch)
+
+        assert "metrics" in result
+        # Should be flattened from [[metric1, metric2], [metric3]] to [metric1, metric2, metric3]
+        assert len(result["metrics"]) == 3
+
+    def test_different_keys_error(self):
+        """Test that different keys across samples raises ValueError"""
+        batch = [
+            {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])},
+            {"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])},
+        ]
+
+        with pytest.raises(ValueError, match="All samples must have the same keys"):
+            collate_padded(batch)
+
+    def test_generic_tensor_handling(self):
+        """Test that any tensor field gets padded correctly"""
+        batch = [
+            {
+                "tokens": torch.tensor([1, 2]),
+                "labels": torch.tensor([3, 4]),
+                "custom_tensor": torch.tensor([100, 200, 300]),
+            },
+            {
+                "tokens": torch.tensor([5, 6, 7, 8]),
+                "labels": torch.tensor([9, 10, 11, 12]),
+                "custom_tensor": torch.tensor([400]),
+            },
+        ]
+        result = collate_padded(batch)
+
+        # Tokens padded to length 4
+        assert result["tokens"].shape == (2, 4)
+        torch.testing.assert_close(
+            result["tokens"], torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8]])
+        )
+
+        # Labels padded to length 4 with CROSS_ENTROPY_IGNORE_IDX
+        assert result["labels"].shape == (2, 4)
+        torch.testing.assert_close(
+            result["labels"],
+            torch.tensor(
+                [
+                    [3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
+                    [9, 10, 11, 12],
+                ]
+            ),
+        )
+
+        # Custom tensor padded to length 3 with 0
+        assert result["custom_tensor"].shape == (2, 3)
+        torch.testing.assert_close(
+            result["custom_tensor"], torch.tensor([[100, 200, 300], [400, 0, 0]])
+        )
