2 changes: 1 addition & 1 deletion apps/sft/llama3_8b.yaml
@@ -27,7 +27,7 @@ lr_scheduler:
warmup_steps: 200

training:
-  local_batch_size: 1
+  local_batch_size: 8
seq_len: 2048
max_norm: 1.0
steps: 1000
27 changes: 10 additions & 17 deletions apps/sft/main.py
@@ -16,15 +16,13 @@
import math
import os
import sys
-from functools import partial
from typing import Any

import torch

import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.controller import ForgeActor
-from forge.data.collate import collate_packed
-from forge.data.datasets.packed import PackedDataset, TextPacker
+from forge.data.collate import collate_padded
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.data.utils import StopAfterOneEpoch
@@ -97,6 +95,13 @@ def record_batch_metrics(self, data_metrics: list):

@endpoint
async def setup(self):
+        # Validate that compile is only used with flex attention
+        if self.job_config.training.compile:
+            raise ValueError(
+                "training.compile=True is not currently supported. "
+                "Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
Review comment (Contributor): Any objection to starting a main issue tracking the nightly build?

Reply (Member Author): Sounds good! But can we first nail down the different subtasks via the Google Doc I just shared? Then we can translate it into a GitHub issue.

"Please set training.compile=false in your config."
)

# all ranks should record loss, except when PP=True. Then, only the last stage should record loss.
self.rank_should_record_loss = True
@@ -152,6 +157,7 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
Raises:
ValueError: If multiple datasets provided (not yet supported)
"""

# TODO felipemello: Currently only support single dataset
if len(dataset_configs) > 1:
raise ValueError(
@@ -197,25 +203,12 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
**dataset_config,
)

-        packer = TextPacker(padding_idx=0)
-        dataset = PackedDataset(
-            dataset=dataset,
-            packer=packer,
-            target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
-        )
-
dataloader = StatefulDataLoader(
dataset=dataset,
batch_size=self.job_config.training.local_batch_size,
-            collate_fn=partial(
-                collate_packed, mask_fn=packer.create_block_mask, device=self.device
-            ),
+            collate_fn=collate_padded,
)

-        # Ultimately we probably want something like this
-        # packer = build_packing_strategy(packing_config)
-        # dataset = build_dataset(dataset_config)
-        # dataloader = build_dataloader(dataloader_config, dataset, packer)
return dataloader

def forward_backward(
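For context on how the new collation path is exercised: setup_data now builds the dataloader directly over the SFT dataset with collate_fn=collate_padded, so each batch is padded to its longest sample instead of being packed to a fixed seq_len. The snippet below is a minimal sketch of that usage only; it swaps in a toy in-memory dataset and a plain torch.utils.data.DataLoader for illustration, whereas the real code uses sft_iterable_dataset and StatefulDataLoader.

import torch
from torch.utils.data import DataLoader

from forge.data.collate import collate_padded

# Toy samples of different lengths standing in for tokenized SFT examples.
samples = [
    {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5, 6])},
    {"tokens": torch.tensor([7, 8]), "labels": torch.tensor([9, 10])},
]

loader = DataLoader(samples, batch_size=2, collate_fn=collate_padded)
batch = next(iter(loader))

# tokens are right-padded with 0, labels with CROSS_ENTROPY_IGNORE_IDX (-100),
# both to the longest sample in the batch.
assert batch["tokens"].shape == (2, 3)
assert batch["labels"][1, -1].item() == -100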
2 changes: 1 addition & 1 deletion apps/sft/qwen3_8b.yaml
@@ -26,7 +26,7 @@ lr_scheduler:
warmup_steps: 200

training:
-  local_batch_size: 1
+  local_batch_size: 8
seq_len: 2048
max_norm: 1.0
steps: 1000
3 changes: 2 additions & 1 deletion src/forge/data/__init__.py
@@ -3,12 +3,13 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from .collate import collate_packed
+from .collate import collate_packed, collate_padded
from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
from .utils import CROSS_ENTROPY_IGNORE_IDX

__all__ = [
"collate_packed",
"collate_padded",
"CROSS_ENTROPY_IGNORE_IDX",
"MetricTransform",
"DefaultDatasetMetricTransform",
66 changes: 66 additions & 0 deletions src/forge/data/collate.py
@@ -7,6 +7,72 @@
from typing import Any, Callable

import torch
import torch.nn.functional as F

from forge.data.utils import CROSS_ENTROPY_IGNORE_IDX


def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
"""
Collate function that pads sequences to the longest sample in the batch.

    Handles arbitrary tensor keys by padding each one to the longest
    sequence in the batch for that key. Uses 0 as the default padding value,
    and CROSS_ENTROPY_IGNORE_IDX (-100) for the 'labels' key.

Non-tensor fields are collected into lists. The 'metrics' field is
special-cased to be flattened (extended) rather than nested.

Args:
batch: List of samples, each containing tensor and non-tensor fields

Returns:
Batched dict with padded tensors and collected non-tensor fields

Raises:
ValueError: If all samples do not have the same keys
"""
if not batch:
return {}

# Verify all samples have the same keys
first_sample_keys = batch[0].keys()
for sample in batch:
if sample.keys() != first_sample_keys:
raise ValueError(
f"All samples must have the same keys. Expected {first_sample_keys}, got {sample.keys()}"
)

collated = {}

for key in first_sample_keys:
if isinstance(batch[0][key], torch.Tensor):
# Find max length for this tensor key
max_len = max(sample[key].size(0) for sample in batch)

# Determine padding value
pad_value = CROSS_ENTROPY_IGNORE_IDX if key == "labels" else 0

# Pad each sample to max_len
padded_tensors = []
for sample in batch:
seq_len = sample[key].size(0)
pad_len = max_len - seq_len
padded = F.pad(sample[key], (0, pad_len), value=pad_value)
padded_tensors.append(padded)

# Stack into batch
collated[key] = torch.stack(padded_tensors)
elif key == "metrics":
# Flatten metrics lists
collated[key] = []
for sample in batch:
collated[key].extend(sample[key])
else:
# Collect other non-tensor fields as lists
collated[key] = [sample[key] for sample in batch]

return collated


def collate_packed(
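A note on the padding values chosen in collate_padded: 'labels' are padded with CROSS_ENTROPY_IGNORE_IDX (-100) because PyTorch's cross-entropy treats targets equal to its ignore_index (default -100) as masked, so padded positions contribute nothing to the loss. The diff does not show the trainer's loss call, so the snippet below is only a sketch of that convention.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)               # (batch, seq_len, vocab_size)
labels = torch.tensor([[3, 7, -100, -100]])  # last two positions are padding

# cross_entropy expects (batch, vocab_size, seq_len) for the input;
# targets equal to ignore_index (-100 by default) are skipped entirely.
loss = F.cross_entropy(logits.transpose(1, 2), labels)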
182 changes: 181 additions & 1 deletion tests/unit_tests/datasets/test_packed.py
@@ -13,7 +13,8 @@
import pytest
import torch

-from forge.data.collate import collate_packed
+from forge.data import CROSS_ENTROPY_IGNORE_IDX
+from forge.data.collate import collate_packed, collate_padded
from forge.data.datasets import HfIterableDataset
from forge.data.datasets.packed import (
_SUPPORTS_FLEX_ATTENTION,
@@ -995,3 +996,182 @@ def test_iter_restart_determinism(self, dataset_factory):
pack2["document_ids"],
msg=f"Pack {i}: document_ids mismatch between iterations",
)


class TestCollatePadded:
"""Test collate_padded function"""

def test_empty_batch(self):
"""Test collating an empty batch"""
result = collate_padded([])
assert result == {}

def test_single_sample(self):
"""Test collating a single sample"""
batch = [
{
"tokens": torch.tensor([1, 2, 3]),
"labels": torch.tensor([4, 5, 6]),
}
]
result = collate_padded(batch)

assert result["tokens"].shape == (1, 3)
assert result["labels"].shape == (1, 3)
torch.testing.assert_close(result["tokens"], torch.tensor([[1, 2, 3]]))
torch.testing.assert_close(result["labels"], torch.tensor([[4, 5, 6]]))

def test_equal_length_samples(self):
"""Test collating samples with equal lengths"""
batch = [
{
"tokens": torch.tensor([1, 2, 3]),
"labels": torch.tensor([4, 5, 6]),
},
{
"tokens": torch.tensor([7, 8, 9]),
"labels": torch.tensor([10, 11, 12]),
},
]
result = collate_padded(batch)

assert result["tokens"].shape == (2, 3)
assert result["labels"].shape == (2, 3)
torch.testing.assert_close(
result["tokens"], torch.tensor([[1, 2, 3], [7, 8, 9]])
)
torch.testing.assert_close(
result["labels"], torch.tensor([[4, 5, 6], [10, 11, 12]])
)

def test_padding_to_longest(self):
"""Test padding shorter sequences to the longest in batch"""
batch = [
{
"tokens": torch.tensor([1, 2]),
"labels": torch.tensor([3, 4]),
},
{
"tokens": torch.tensor([5, 6, 7, 8]),
"labels": torch.tensor([9, 10, 11, 12]),
},
{
"tokens": torch.tensor([13, 14, 15]),
"labels": torch.tensor([16, 17, 18]),
},
]
result = collate_padded(batch)

# All should be padded to length 4 (longest)
assert result["tokens"].shape == (3, 4)
assert result["labels"].shape == (3, 4)

# Check tokens padding (padded with 0)
torch.testing.assert_close(
result["tokens"],
torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8], [13, 14, 15, 0]]),
)

# Check labels padding (padded with CROSS_ENTROPY_IGNORE_IDX)
torch.testing.assert_close(
result["labels"],
torch.tensor(
[
[3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
[9, 10, 11, 12],
[16, 17, 18, CROSS_ENTROPY_IGNORE_IDX],
]
),
)

def test_non_tensor_fields_preserved(self):
"""Test that non-tensor fields are collected correctly"""
batch = [
{
"tokens": torch.tensor([1, 2]),
"labels": torch.tensor([3, 4]),
"metadata": "sample1",
},
{
"tokens": torch.tensor([5, 6, 7]),
"labels": torch.tensor([8, 9, 10]),
"metadata": "sample2",
},
]
result = collate_padded(batch)

assert "metadata" in result
assert result["metadata"] == ["sample1", "sample2"]

def test_metrics_flattened(self):
"""Test that metrics lists are flattened"""
batch = [
{
"tokens": torch.tensor([1, 2]),
"labels": torch.tensor([3, 4]),
"metrics": [
type("Metric", (), {"key": "loss", "value": 1.0})(),
type("Metric", (), {"key": "acc", "value": 0.9})(),
],
},
{
"tokens": torch.tensor([5, 6, 7]),
"labels": torch.tensor([8, 9, 10]),
"metrics": [type("Metric", (), {"key": "loss", "value": 2.0})()],
},
]
result = collate_padded(batch)

assert "metrics" in result
# Should be flattened from [[metric1, metric2], [metric3]] to [metric1, metric2, metric3]
assert len(result["metrics"]) == 3

def test_different_keys_error(self):
"""Test that different keys across samples raises ValueError"""
batch = [
{"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])},
{"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])},
]

with pytest.raises(ValueError, match="All samples must have the same keys"):
collate_padded(batch)

def test_generic_tensor_handling(self):
"""Test that any tensor field gets padded correctly"""
batch = [
{
"tokens": torch.tensor([1, 2]),
"labels": torch.tensor([3, 4]),
"custom_tensor": torch.tensor([100, 200, 300]),
},
{
"tokens": torch.tensor([5, 6, 7, 8]),
"labels": torch.tensor([9, 10, 11, 12]),
"custom_tensor": torch.tensor([400]),
},
]
result = collate_padded(batch)

# Tokens padded to length 4
assert result["tokens"].shape == (2, 4)
torch.testing.assert_close(
result["tokens"], torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8]])
)

# Labels padded to length 4 with CROSS_ENTROPY_IGNORE_IDX
assert result["labels"].shape == (2, 4)
torch.testing.assert_close(
result["labels"],
torch.tensor(
[
[3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
[9, 10, 11, 12],
]
),
)

# Custom tensor padded to length 3 with 0
assert result["custom_tensor"].shape == (2, 3)
torch.testing.assert_close(
result["custom_tensor"], torch.tensor([[100, 200, 300], [400, 0, 0]])
)
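The new collate tests can be run in isolation with pytest's -k selection. A small sketch, assuming the repository's usual pytest setup:

import pytest

# Select only the TestCollatePadded class from the packed-dataset test module.
pytest.main(["tests/unit_tests/datasets/test_packed.py", "-k", "TestCollatePadded"])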