Commits (57)
901723f
rfc
Jun 4, 2025
ab02d75
Revert "rfc"
Jun 4, 2025
1c0117b
Merge branch 'main' of https://github.com/pytorch/torchtune
Jun 6, 2025
1cc6946
Merge branch 'main' of https://github.com/pytorch/torchtune
Jun 11, 2025
2a2efa2
add packed functions
felipemello1 Jun 11, 2025
6e14b06
enable on full recipe
felipemello1 Jun 11, 2025
f9db469
fix imports + formatting
Jun 11, 2025
ff6fdbe
add max_steps_per_epoch requirement
felipemello1 Jun 11, 2025
5e447ab
address blockers
Jun 11, 2025
7a1dfa5
Merge branch 'main' into online_packing
Jun 11, 2025
1ffd459
Merge branch 'main' into online_packing
Jun 11, 2025
13cda28
small fixes
felipemello1 Jun 12, 2025
d26769c
add md doc
felipemello1 Jun 12, 2025
20cfa80
Merge remote-tracking branch 'refs/remotes/origin/online_packing' int…
felipemello1 Jun 12, 2025
59b8cab
update comments
felipemello1 Jun 12, 2025
5d7d496
update comments
felipemello1 Jun 12, 2025
e193926
update comment
felipemello1 Jun 12, 2025
40d79f4
update comment
felipemello1 Jun 12, 2025
3cab533
first commit
felipemello1 Jun 25, 2025
2212b19
update tests
felipemello1 Jun 25, 2025
4345832
Merge remote-tracking branch 'joecummings/impl-step-based-ckpt' into …
felipemello1 Jun 25, 2025
2eb68b6
linter
Jun 25, 2025
2e51e04
tests pass
Jun 25, 2025
93fa743
it works
Jun 26, 2025
aa9e6f4
remove code
Jun 26, 2025
a5e7234
Merge branch 'iterable_dataset_final' into online_packing
Jun 26, 2025
55be775
adjust pack to have metrics
Jun 26, 2025
382c4e9
remove comment
Jun 26, 2025
5b188ed
update metrics to use handlers
felipemello1 Jul 2, 2025
2eab08d
remove file after refactoring
felipemello1 Jul 2, 2025
58491f1
add distributed tsts
felipemello1 Jul 2, 2025
da7245d
Merge branch 'iterable_dataset_final' of github.com:felipemello1/torc…
Jul 2, 2025
96424d0
tests pass
Jul 2, 2025
853147b
optimize SFTOutputTransform
Jul 2, 2025
96bc317
use ds.sampling_weight
felipemello1 Jul 2, 2025
3c9d161
add sampling log to interlead dataset
felipemello1 Jul 2, 2025
4804663
fix nested interleave
felipemello1 Jul 3, 2025
2fe4b40
changes to TuneIterableDataset
felipemello1 Jul 3, 2025
72211c9
add IterableDataset back
Jul 3, 2025
b350ac7
nested interleaved + dataset.info
felipemello1 Jul 6, 2025
f9a1aec
nits hf_iterable
felipemello1 Jul 6, 2025
f7a3aa7
update readme
felipemello1 Jul 6, 2025
17878bf
make metric dataset name explicit
felipemello1 Jul 6, 2025
101e96e
update recipe to share log freq + validagtion msg
felipemello1 Jul 6, 2025
1b3f3fc
update interleaved tests to do nesting
Jul 6, 2025
fac3fd5
lint
Jul 6, 2025
29ba1cb
error if duplicated metric name
Jul 7, 2025
f89eefe
improve docs
Jul 7, 2025
de942bf
Merge branch 'iterable_dataset_final' into online_packing
felipemello1 Jul 7, 2025
d6680b7
rename from strategy to packer
felipemello1 Jul 7, 2025
d3be015
tensors instead of lists
felipemello1 Jul 7, 2025
c8bfbb2
tests
felipemello1 Jul 7, 2025
fd41842
docs
felipemello1 Jul 7, 2025
734128e
tests + lint pass
Jul 7, 2025
23bd9fb
test collate + dataloader
felipemello1 Jul 7, 2025
fb7b9aa
clean up
Jul 7, 2025
4c505e0
improve packed testing
Jul 8, 2025
51 changes: 51 additions & 0 deletions planning/ontheflypacking.md
@@ -0,0 +1,51 @@
### What:
Packing is the process of concatenating samples until a target sequence length is reached, which reduces the number of padding tokens in a batch. To avoid contamination between samples, we use a document-level causal mask, and we rely on flex attention to handle this special mask efficiently.

Example:
```python
# The current pack with one sample
pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]}

# The next sample to be added
sample = {"tokens": [5, 6], "labels": [7, 8]}

# After adding the sample
added_docs = add_sample_to_pack(pack, sample, next_doc_id=1)
print(pack)
>>> {"tokens": [1, 2, 5, 6],
"labels": [3, 4, 7, 8],
"document_ids": [0, 0, 1, 1],
"input_pos": [0, 1, 0, 1]}

create_block_causal_mask(document_ids)
>>> [
[1, 0, 0, 0],
[1, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 1, 1],
]
```
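
For reference, the document-level mask above maps naturally onto flex attention's document-masking pattern. Below is a hedged sketch of how the mask could be built (not code from this PR; in the PR the mask comes from `strategy.create_block_mask`, and the helper name here is illustrative):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask


def make_document_causal_mask(document_ids: torch.Tensor, device: str = "cpu"):
    """Build a flex-attention BlockMask where each token attends only to earlier
    tokens of the same document. `document_ids` has shape [batch, seq_len]."""
    document_ids = document_ids.to(device)

    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        causal = q_idx >= kv_idx
        return same_doc & causal

    batch, seq_len = document_ids.shape
    return create_block_mask(
        mask_mod, B=batch, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
    )
```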

### Goal:
0) Make packing a first-class citizen in torchtune, available for all models and recipes.

### Context:
1) We currently only support map-style packing: the dataset is pre-processed before training, which is not scalable.
2) Packing exists only for SFT + text data. There is no contract for extending it to multimodal, DPO, etc.
3) The collate function has to be aware of the packing logic, which is currently hardcoded in the recipe with if/else.

### Solution:
4) Implement new on-the-fly packing that takes any iterable dataset as input;
5) The packing contract consists of:
i) a `PackingStrategy` that defines a) how to pack and b) the **_mask_mod** used for flex attention;
ii) an `IterablePackedDataset` that takes a) any `PackingStrategy` and b) any **iterable dataset** as input and yields packed samples;
iii) a collate function (`collate_packed`) that takes the batch of packed samples and a **mask_fn** (e.g. `strategy.create_block_mask`) to generate the attention mask on the fly.
To define a new packing strategy, the user only needs to implement the `PackingStrategy` class; a minimal sketch of this contract is shown below.
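
A minimal sketch of the contract, assuming the interface implied by the example above and by the recipe/collate changes below (method names other than `create_block_mask`, and the exact constructor arguments, are assumptions):

```python
from abc import ABC, abstractmethod
from typing import Any

import torch


class PackingStrategy(ABC):
    """Defines how samples are packed and how the flex-attention mask is built."""

    def __init__(self, padding_idx: int, ignore_idx: int):
        self.padding_idx = padding_idx
        self.ignore_idx = ignore_idx

    @abstractmethod
    def add_sample_to_pack(
        self, pack: dict[str, list[Any]], sample: dict[str, Any], next_doc_id: int
    ) -> int:
        """Append `sample` to `pack` in place and return how many documents were added."""
        ...

    @abstractmethod
    def create_block_mask(self, document_ids: torch.Tensor, device: str):
        """Build the block-causal attention mask for a batch of packed `document_ids`."""
        ...
```

With this contract, `IterablePackedDataset` only needs to iterate the wrapped dataset, call `add_sample_to_pack` until `target_tokens_per_pack` is reached, and yield the finished pack; the collate function then calls `create_block_mask` on the batched `document_ids`.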

### Implementation:
6) Updated `full_finetune_distributed.py` to use `IterablePackedDataset` when packing is enabled. There are open challenges related to iterable datasets that will be tackled in a separate iterable-dataset PR; the changes here are only what was needed to get this RFC running.

### Not in this PR:
7) **Logging**: Since we cannot call `len(iterable_dataset)`, we need to add proper logging/metadata to help users understand how far along they are in each dataset, plus metrics about the samples (avg. tokens per sample, avg. samples per pack, etc.); a rough sketch of such a counter is shown after this list.
8) **Packing-aware loss**: For SFT, the same loss works for map-style and packed data. This is not the case for DPO/GRPO, which would need different masking. Future work will have to define how packing is associated with a loss that supports it.
9) **Packing-aware metrics**: Advanced metrics, such as per-sample log-probabilities, would need to be packing-aware.
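
Purely illustrative and not part of this PR, a counter for the kind of packing metrics described in 7) might look like:

```python
from dataclasses import dataclass


@dataclass
class PackingMetrics:
    """Running packing statistics (avg samples per pack, avg tokens per sample)."""

    total_packs: int = 0
    total_samples: int = 0
    total_tokens: int = 0

    def on_pack_yielded(self, num_samples: int, num_tokens: int) -> None:
        # Called once per finished pack with its sample and token counts.
        self.total_packs += 1
        self.total_samples += num_samples
        self.total_tokens += num_tokens

    @property
    def avg_samples_per_pack(self) -> float:
        return self.total_samples / max(self.total_packs, 1)

    @property
    def avg_tokens_per_sample(self) -> float:
        return self.total_tokens / max(self.total_samples, 1)
```
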
8 changes: 6 additions & 2 deletions recipes/configs/llama3_2/3B_full.yaml
@@ -24,17 +24,21 @@ output_dir: /tmp/torchtune/llama3_2_3B/full # /tmp may be deleted by your system
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model
max_seq_len: null
max_seq_len: 4096

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
packed: False # True increases speed
split: train[:95%]
seed: null
shuffle: True
batch_size: 4

# On-the-fly packing strategy
# Set packing_strategy: null to disable packing
packing_strategy:
_component_: torchtune.datasets.TextPackingStrategy

# Validation
run_val_every_n_steps: null # Change to an integer to enable validation every N steps
dataset_val:
115 changes: 85 additions & 30 deletions recipes/full_finetune_distributed.py
@@ -20,13 +20,13 @@
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
from torch.optim import Optimizer
from torch.utils.data import IterableDataset
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from torchtune import config, modules, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.datasets import ConcatDataset, IterablePackedDataset
from torchtune.modules.embedding_utils import resize_token_embeddings
from torchtune.modules.loss import SFTLoss
from torchtune.modules.moe import utils as moe_utils
@@ -432,12 +432,19 @@ def setup(self, cfg: DictConfig) -> None:

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after both of these are initialized
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
collate_name = cfg.get("collate_fn", None)
if collate_name is None:
if cfg.get("packing_strategy") is not None:
collate_name = "torchtune.data.collate_packed"
else:
collate_name = "torchtune.data.padded_collate_sft"

self._dataloader = self._setup_data(
cfg_dataset=cfg.dataset,
shuffle=cfg.shuffle,
batch_size=cfg.batch_size,
collate_fn=collate_name,
cfg_packing_strategy=cfg.get("packing_strategy", None),
)

# Setup validation dataloader if validation dataset is provided
@@ -448,6 +455,7 @@ def setup(self, cfg: DictConfig) -> None:
cfg_dataset=cfg.dataset_val,
batch_size=batch_size_val,
collate_fn=collate_name,
cfg_packing_strategy=cfg.get("packing_strategy", None),
shuffle=False,
)

@@ -458,14 +466,23 @@
# by the dataloader, the max_steps_per_epoch param set by the user and the
# gradient_accumulation_steps param. This value is used for logging and tracking
# training state. The computation should happen after the dataloader has been setup
self._steps_per_epoch = (
len(self._dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):

# NOTE: Hack to get it running. needs to be properly addressed.
if isinstance(self._dataloader.dataset, IterableDataset):
if self.max_steps_per_epoch is None:
raise ValueError(
"max_steps_per_epoch must be specified for iterable datasets."
)
self._steps_per_epoch = self.max_steps_per_epoch
else:
self._steps_per_epoch = (
len(self._dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch

# Setup lr scheduler
@@ -775,47 +792,84 @@ def _setup_data(
shuffle: bool,
batch_size: int,
collate_fn: str,
cfg_packing_strategy: Optional[DictConfig] = None,
dataloader_state_dict: Optional[dict[str, Any]] = None,
) -> StatefulDataLoader:
"""
All data related setup happens here. This recipe currently accepts only map-style datasets;
when a packing_strategy is configured, they are wrapped in an IterablePackedDataset for
on-the-fly packing. If a state_dict is provided (meaning we are resuming a training run),
it is loaded into the dataloader.
"""
# 1. Instantiate the base map-style dataset (to be replaced with IterableDataset)
if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = getattr(ds, "packed", False)
else:
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

sampler = StatefulDistributedSampler(
ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0
)

# 2. Set up packing
if cfg_packing_strategy:
if self._is_rank_zero:
self._logger.info("Using IterablePackedDataset for on-the-fly packing.")

packing_strategy = config.instantiate(
cfg_packing_strategy,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)

# NOTE: This is a temporary hack to make map-style dataset
# compatible with IterablePackedDataset
# ------------------------------------------------------------
class _SamplerWrapper(IterableDataset):
def __init__(self, data, sampler):
self._data = data
self._sampler = sampler

def __iter__(self):
for i in self._sampler:
yield self._data[i]

iterable_ds = _SamplerWrapper(ds, sampler)

# Sampler must be None for iterable datasets
sampler = None
# ------------------------------------------------------------

final_ds = IterablePackedDataset(
dataset=iterable_ds,
strategy=packing_strategy,
target_tokens_per_pack=self._tokenizer.max_seq_len,
)

collate_callable = partial(
_get_component_from_path(collate_fn),
mask_fn=packing_strategy.create_block_mask,
device=self._device,
)
else: # Fallback for non-packed datasets

final_ds = ds

collate_callable = partial(
_get_component_from_path(collate_fn),
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor,
)

dataloader = StatefulDataLoader(
dataset=ds,
dataset=final_ds,
batch_size=batch_size,
sampler=sampler,
collate_fn=(
partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor,
)
if not packed
else padded_collate_packed
),
# dropping last avoids shape issues with compile + flex attention
collate_fn=collate_callable,
drop_last=True,
)

@@ -910,7 +964,8 @@ def train(self) -> None:
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
self._dataloader.sampler.set_epoch(curr_epoch)
# NOTE: Temporary hack to make it work with the new packing strategy
self._dataloader.dataset.dataset._sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
# Start tracking CUDA memory for active steps for just the first epoch
if (
2 changes: 2 additions & 0 deletions torchtune/data/__init__.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from torchtune.data._collate import (
collate_packed,
left_pad_sequence,
padded_collate,
padded_collate_dpo,
@@ -59,5 +60,6 @@
"padded_collate",
"padded_collate_tiled_images_and_mask",
"padded_collate_packed",
"collate_packed",
"load_image",
]
30 changes: 30 additions & 0 deletions torchtune/data/_collate.py
@@ -12,6 +12,36 @@
from torchtune.modules.attention_utils import packed_block_causal_mask


def collate_packed(
batch: list[dict[str, torch.Tensor]], mask_fn: callable, device: str
) -> dict[str, torch.Tensor]:
"""
Generic collate function for packed samples from an IterablePackedDataset.
This function handles tensor stacking and delegates attention mask creation
to a provided `mask_fn`.
"""
if not batch:
return {}

# Assumes all samples in the batch have the same keys, which are all tensors.
keys_to_stack = batch[0].keys()
collated = {}
for key in keys_to_stack:
if isinstance(batch[0][key], torch.Tensor):
collated[key] = torch.stack([sample[key] for sample in batch], dim=0)
else:
# TODO: Remove? I don't see a situation where it would not be a tensor.
collated[key] = [sample[key] for sample in batch]

# Delegate mask creation to the provided specialized function
# TODO: investigate the need for device here. Currently we hardcode it to cuda in utilities.
# Shouldn't we just move it to device later?
collated["mask"] = mask_fn(collated["document_ids"], device=device)

return collated


def left_pad_sequence(
sequences: list[torch.Tensor],
batch_first: bool = False,
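
A hedged usage sketch of `collate_packed` above; the `toy_mask_fn` here is illustrative only (the PR wires in `strategy.create_block_mask` instead), but it follows the `mask_fn(document_ids, device=...)` signature used by the collate:

```python
import torch

from torchtune.data import collate_packed


def toy_mask_fn(document_ids: torch.Tensor, device: str) -> torch.Tensor:
    """Dense block-causal boolean mask: attend causally within the same document."""
    doc = document_ids.to(device)
    same_doc = doc.unsqueeze(-1) == doc.unsqueeze(-2)
    causal = torch.tril(
        torch.ones(doc.shape[-1], doc.shape[-1], dtype=torch.bool, device=device)
    )
    return same_doc & causal


batch = [
    {
        "tokens": torch.tensor([1, 2, 5, 6]),
        "labels": torch.tensor([3, 4, 7, 8]),
        "document_ids": torch.tensor([0, 0, 1, 1]),
        "input_pos": torch.tensor([0, 1, 0, 1]),
    }
]
out = collate_packed(batch, mask_fn=toy_mask_fn, device="cpu")
print(out["tokens"].shape, out["mask"].shape)  # torch.Size([1, 4]) torch.Size([1, 4, 4])
```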
6 changes: 6 additions & 0 deletions torchtune/datasets/__init__.py
@@ -12,6 +12,10 @@
from torchtune.datasets._grammar import grammar_dataset
from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset
from torchtune.datasets._instruct import instruct_dataset
from torchtune.datasets._iterable_packed import (
IterablePackedDataset,
TextPackingStrategy,
)
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._preference import preference_dataset, PreferenceDataset
from torchtune.datasets._samsum import samsum_dataset
@@ -44,4 +48,6 @@
"SFTDataset",
"hh_rlhf_helpful_dataset",
"multimodal",
"IterablePackedDataset",
"TextPackingStrategy",
]
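
To tie the new exports together, a hedged end-to-end sketch of the intended wiring, mirroring the recipe changes above; the constructor arguments and the toy sample schema are assumptions and may differ from the actual implementation:

```python
from functools import partial

from torch.utils.data import DataLoader, IterableDataset

from torchtune.data import collate_packed
from torchtune.datasets import IterablePackedDataset, TextPackingStrategy


class ToyIterable(IterableDataset):
    """Placeholder iterable dataset; the recipe wraps the real SFT dataset instead."""

    def __iter__(self):
        for i in range(100):
            yield {"tokens": [i, i + 1, i + 2], "labels": [i, i + 1, i + 2]}


strategy = TextPackingStrategy(padding_idx=0, ignore_idx=-100)
packed_ds = IterablePackedDataset(
    dataset=ToyIterable(),
    strategy=strategy,
    target_tokens_per_pack=4096,
)
dataloader = DataLoader(
    packed_ds,
    batch_size=4,
    collate_fn=partial(
        collate_packed, mask_fn=strategy.create_block_mask, device="cpu"
    ),
    drop_last=True,  # dropping last avoids shape issues with compile + flex attention
)
```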