2 changes: 1 addition & 1 deletion apps/sft/llama3_8b.yaml
@@ -27,7 +27,7 @@ lr_scheduler:
   warmup_steps: 200
 
 training:
-  local_batch_size: 1
+  local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000
27 changes: 10 additions & 17 deletions apps/sft/main.py
@@ -16,15 +16,13 @@
 import math
 import os
 import sys
-from functools import partial
 from typing import Any
 
 import torch
 
 import torchtitan.experiments.forge.train_spec as forge_train_spec
 from forge.controller import ForgeActor
-from forge.data.collate import collate_packed
-from forge.data.datasets.packed import PackedDataset, TextPacker
+from forge.data.collate import collate_padded
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.data.utils import StopAfterOneEpoch
@@ -97,6 +95,13 @@ def record_batch_metrics(self, data_metrics: list):
 
     @endpoint
     async def setup(self):
+        # Validate that compile is only used with flex attention
+        if self.job_config.training.compile:
+            raise ValueError(
+                "training.compile=True is not currently supported. "
+                "Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
[Review thread on the PyTorch-nightly line — Contributor]
Any objection to starting a main issue to track the nightly build?

[Reply — Author (Member)]
Sounds good! But can we first nail down the different subtasks via the Google Doc I just shared? Then we can translate it into a GitHub issue.

+                "Please set training.compile=false in your config."
+            )
 
         # all ranks should record loss, except when PP=True. Then, only the last stage should record loss.
         self.rank_should_record_loss = True
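If the new guard trips, the fix lives in the app config. A minimal sketch, assuming the YAML key path mirrors the job_config.training.compile field read above (whether the shipped apps/sft/*.yaml files already carry this key is not shown in this diff):

training:
  compile: false  # compile needs flex attention + PyTorch nightly; keep off for now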
@@ -152,6 +157,7 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
         Raises:
             ValueError: If multiple datasets provided (not yet supported)
         """
+
         # TODO felipemello: Currently only support single dataset
         if len(dataset_configs) > 1:
             raise ValueError(
@@ -197,25 +203,12 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
             **dataset_config,
         )
 
-        packer = TextPacker(padding_idx=0)
-        dataset = PackedDataset(
-            dataset=dataset,
-            packer=packer,
-            target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
-        )
-
         dataloader = StatefulDataLoader(
             dataset=dataset,
             batch_size=self.job_config.training.local_batch_size,
-            collate_fn=partial(
-                collate_packed, mask_fn=packer.create_block_mask, device=self.device
-            ),
+            collate_fn=collate_padded,
         )
 
-        # Ultimately we probably want something like this
-        # packer = build_packing_strategy(packing_config)
-        # dataset = build_dataset(dataset_config)
-        # dataloader = build_dataloader(dataloader_config, dataset, packer)
         return dataloader
 
     def forward_backward(
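To see what the simplified setup_data wiring now produces, here is a minimal self-contained sketch. The toy in-memory dataset is invented for illustration, and it assumes StatefulDataLoader is torchdata's (torchdata.stateful_dataloader), which this file appears to use:

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from forge.data.collate import collate_padded

# Toy variable-length samples shaped like the SFT dataset output
toy_dataset = [
    {"tokens": torch.randint(0, 100, (n,)), "labels": torch.randint(0, 100, (n,))}
    for n in (5, 7, 3, 6)
]

dataloader = StatefulDataLoader(
    dataset=toy_dataset,
    batch_size=2,
    collate_fn=collate_padded,  # pads each batch to its longest sample
)

for batch in dataloader:
    # tokens and labels come back as [batch_size, max_len_in_batch]
    assert batch["tokens"].shape == batch["labels"].shape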
2 changes: 1 addition & 1 deletion apps/sft/qwen3_8b.yaml
@@ -26,7 +26,7 @@ lr_scheduler:
   warmup_steps: 200
 
 training:
-  local_batch_size: 1
+  local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000
3 changes: 2 additions & 1 deletion src/forge/data/__init__.py
@@ -3,12 +3,13 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from .collate import collate_packed
+from .collate import collate_packed, collate_padded
 from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
 from .utils import CROSS_ENTROPY_IGNORE_IDX
 
 __all__ = [
     "collate_packed",
+    "collate_padded",
     "CROSS_ENTROPY_IGNORE_IDX",
     "MetricTransform",
     "DefaultDatasetMetricTransform",
59 changes: 59 additions & 0 deletions src/forge/data/collate.py
@@ -7,6 +7,65 @@
 from typing import Any, Callable
 
 import torch
+import torch.nn.functional as F
+
+from forge.data.utils import CROSS_ENTROPY_IGNORE_IDX
+
+
+def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
+    """
+    Collate function that pads sequences to the longest sample in the batch.
+
+    Pads 'tokens' with 0 and 'labels' with CROSS_ENTROPY_IGNORE_IDX (-100).
+    Non-tensor fields (like metrics) are collected into lists and flattened
+    if all items are lists.
[Review thread on lines +19 to +21 — Contributor]
Is it common practice to assume 'tokens' and 'labels' are the keys for collate_padded?
+
+    Args:
+        batch: List of samples, each containing 'tokens' and 'labels' tensors
+
+    Returns:
+        Batched dict with padded tensors
+    """
+    if not batch:
+        return {}
+
+    # Find max length in batch
+    max_len = max(sample["tokens"].size(0) for sample in batch)
+
+    # Initialize lists for batched tensors
+    tokens_list = []
+    labels_list = []
+
+    # Pad each sample to max_len
+    for sample in batch:
+        seq_len = sample["tokens"].size(0)
+        pad_len = max_len - seq_len
+
+        # Pad tokens with 0
+        padded_tokens = F.pad(sample["tokens"], (0, pad_len), value=0)
+        tokens_list.append(padded_tokens)
+
+        # Pad labels with CROSS_ENTROPY_IGNORE_IDX (-100)
+        padded_labels = F.pad(
+            sample["labels"], (0, pad_len), value=CROSS_ENTROPY_IGNORE_IDX
+        )
+        labels_list.append(padded_labels)
+
+    # Stack into batch
+    result = {
+        "tokens": torch.stack(tokens_list),
+        "labels": torch.stack(labels_list),
+    }
+
+    # Collect non-tensor fields (like metrics)
+    for key in batch[0].keys():
+        if key not in ["tokens", "labels"]:
+            result[key] = [sample[key] for sample in batch]
+            # Flatten if all are lists
+            if all(isinstance(item, list) for item in result[key]):
+                result[key] = [item for sublist in result[key] for item in sublist]
[Review thread on lines +64 to +66 — Contributor]
Is this a common practice? It feels like an unnecessary operation / tribal knowledge.


+    return result
+
+
 def collate_packed(
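As a concrete illustration of both review questions above, here is a hedged usage sketch of collate_padded. The toy samples are invented, and the string "metrics" entries stand in for whatever metric objects the dataset transform emits:

import torch
from torch.utils.data import DataLoader

from forge.data.collate import collate_padded

samples = [
    {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5, 6]), "metrics": ["a", "b"]},
    {"tokens": torch.tensor([7, 8]), "labels": torch.tensor([9, 10]), "metrics": ["c"]},
]

loader = DataLoader(samples, batch_size=2, collate_fn=collate_padded)
batch = next(iter(loader))

print(batch["tokens"])   # tensor([[1, 2, 3], [7, 8, 0]])            -- padded with 0
print(batch["labels"])   # tensor([[4, 5, 6], [9, 10, -100]])        -- padded with -100
print(batch["metrics"])  # ['a', 'b', 'c'] -- per-sample lists flattened into one list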
132 changes: 131 additions & 1 deletion tests/unit_tests/datasets/test_packed.py
@@ -13,7 +13,8 @@
 import pytest
 import torch
 
-from forge.data.collate import collate_packed
+from forge.data import CROSS_ENTROPY_IGNORE_IDX
+from forge.data.collate import collate_packed, collate_padded
 from forge.data.datasets import HfIterableDataset
 from forge.data.datasets.packed import (
     _SUPPORTS_FLEX_ATTENTION,
@@ -995,3 +996,132 @@ def test_iter_restart_determinism(self, dataset_factory):
pack2["document_ids"],
msg=f"Pack {i}: document_ids mismatch between iterations",
)


class TestCollatePadded:
"""Test collate_padded function"""

def test_empty_batch(self):
"""Test collating an empty batch"""
result = collate_padded([])
assert result == {}

def test_single_sample(self):
"""Test collating a single sample"""
batch = [
{
"tokens": torch.tensor([1, 2, 3]),
"labels": torch.tensor([4, 5, 6]),
}
]
result = collate_padded(batch)

assert result["tokens"].shape == (1, 3)
assert result["labels"].shape == (1, 3)
torch.testing.assert_close(result["tokens"], torch.tensor([[1, 2, 3]]))
torch.testing.assert_close(result["labels"], torch.tensor([[4, 5, 6]]))

def test_equal_length_samples(self):
"""Test collating samples with equal lengths"""
batch = [
{
"tokens": torch.tensor([1, 2, 3]),
"labels": torch.tensor([4, 5, 6]),
},
{
"tokens": torch.tensor([7, 8, 9]),
"labels": torch.tensor([10, 11, 12]),
},
]
result = collate_padded(batch)

assert result["tokens"].shape == (2, 3)
assert result["labels"].shape == (2, 3)
torch.testing.assert_close(
result["tokens"], torch.tensor([[1, 2, 3], [7, 8, 9]])
)
torch.testing.assert_close(
result["labels"], torch.tensor([[4, 5, 6], [10, 11, 12]])
)

def test_padding_to_longest(self):
"""Test padding shorter sequences to the longest in batch"""
batch = [
{
"tokens": torch.tensor([1, 2]),
"labels": torch.tensor([3, 4]),
},
{
"tokens": torch.tensor([5, 6, 7, 8]),
"labels": torch.tensor([9, 10, 11, 12]),
},
{
"tokens": torch.tensor([13, 14, 15]),
"labels": torch.tensor([16, 17, 18]),
},
]
result = collate_padded(batch)

# All should be padded to length 4 (longest)
assert result["tokens"].shape == (3, 4)
assert result["labels"].shape == (3, 4)

# Check tokens padding (padded with 0)
torch.testing.assert_close(
result["tokens"],
torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8], [13, 14, 15, 0]]),
)

# Check labels padding (padded with CROSS_ENTROPY_IGNORE_IDX)
torch.testing.assert_close(
result["labels"],
torch.tensor(
[
[3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
[9, 10, 11, 12],
[16, 17, 18, CROSS_ENTROPY_IGNORE_IDX],
]
),
)

def test_non_tensor_fields_preserved(self):
"""Test that non-tensor fields are collected correctly"""
batch = [
{
"tokens": torch.tensor([1, 2]),
"labels": torch.tensor([3, 4]),
"metadata": "sample1",
},
{
"tokens": torch.tensor([5, 6, 7]),
"labels": torch.tensor([8, 9, 10]),
"metadata": "sample2",
},
]
result = collate_padded(batch)

assert "metadata" in result
assert result["metadata"] == ["sample1", "sample2"]

def test_metrics_flattened(self):
"""Test that metrics lists are flattened"""
batch = [
{
"tokens": torch.tensor([1, 2]),
"labels": torch.tensor([3, 4]),
"metrics": [
type("Metric", (), {"key": "loss", "value": 1.0})(),
type("Metric", (), {"key": "acc", "value": 0.9})(),
],
},
{
"tokens": torch.tensor([5, 6, 7]),
"labels": torch.tensor([8, 9, 10]),
"metrics": [type("Metric", (), {"key": "loss", "value": 2.0})()],
},
]
result = collate_padded(batch)

assert "metrics" in result
# Should be flattened from [[metric1, metric2], [metric3]] to [metric1, metric2, metric3]
assert len(result["metrics"]) == 3