[RFC] on-the-fly packing #2819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
felipemello1 wants to merge 57 commits into meta-pytorch:impl-step-based-ckpt from felipemello1:online_packing
### What:
Packing is the process of concatenating samples until a target sequence length is reached. This reduces the number of padding tokens in a batch. To avoid cross-contamination between samples, we use a document-level causal mask, and we use flex attention to handle this special mask efficiently.

Example:
```python
# The current pack with one sample
pack = {"tokens": [1, 2], "labels": [3, 4], "document_ids": [0, 0], "input_pos": [0, 1]}

# The next sample to be added
sample = {"tokens": [5, 6], "labels": [7, 8]}

# After adding the sample
added_docs = add_sample_to_pack(pack, sample, next_doc_id=1)
print(pack)
>>> {"tokens": [1, 2, 5, 6],
     "labels": [3, 4, 7, 8],
     "document_ids": [0, 0, 1, 1],
     "input_pos": [0, 1, 0, 1]}

create_block_causal_mask(document_ids)
>>> [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 1],
]
```
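The two helpers in the example can be sketched in a few lines of plain Python. Note that `add_sample_to_pack` and `create_block_causal_mask` here are illustrative reconstructions matching the example's inputs and outputs, not the actual implementations in this PR:

```python
def add_sample_to_pack(pack: dict, sample: dict, next_doc_id: int) -> int:
    # Append the sample's tokens/labels, tag them with a fresh document id,
    # and restart input_pos so positions are local to each document.
    n = len(sample["tokens"])
    pack["tokens"] += sample["tokens"]
    pack["labels"] += sample["labels"]
    pack["document_ids"] += [next_doc_id] * n
    pack["input_pos"] += list(range(n))
    return 1  # number of documents added


def create_block_causal_mask(document_ids: list[int]) -> list[list[int]]:
    # Position i may attend to position j iff j <= i (causal) and both
    # positions belong to the same document (no cross-contamination).
    n = len(document_ids)
    return [
        [1 if j <= i and document_ids[i] == document_ids[j] else 0 for j in range(n)]
        for i in range(n)
    ]
```

Running these on the example above reproduces the pack and the block-causal mask shown.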
### Goal:
0) Make packing a first-class citizen in TorchTune, available for all sorts of models and recipes.
### Context:
1) We currently have map-style packing: the dataset is fully pre-processed before training, which does not scale.
2) Packing exists only for SFT + text data. There is no contract for how to extend it to multimodal, DPO, etc.
3) The collate function has to be aware of the packing logic, which is currently hardcoded in the recipe with if/else branches.
### Solution:
4) Implement new on-the-fly packing that takes any iterable dataset as input;
5) The packing contract consists of:
    i) a `PackingStrategy` that defines a) how to pack and b) the **_mask_mod** used for flex attention;
    ii) an `IterablePackedDataset` that takes a) any `PackingStrategy` and b) any **iterable dataset** as input, and yields packed samples;
    iii) a `packed_collate_fn` that takes a batch of packed samples and a **mask_fn** (e.g. `strategy.create_block_mask`) to generate the attention mask on the fly.
To define a new packing strategy, the user only needs to implement the `PackingStrategy` class.
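The contract in (5) can be sketched roughly as follows. Class and method names follow the RFC, but the signatures, the greedy size check, and `target_tokens_per_pack` are assumptions for illustration, not the actual code in this PR:

```python
from abc import ABC, abstractmethod
from typing import Iterable, Iterator


class PackingStrategy(ABC):
    """Hypothetical sketch of the packing contract described above."""

    def __init__(self, target_tokens_per_pack: int):
        self.target_tokens_per_pack = target_tokens_per_pack

    @abstractmethod
    def add_sample_to_pack(self, pack: dict, sample: dict, next_doc_id: int) -> int:
        """Append `sample` to `pack` in place; return the number of documents added."""

    @abstractmethod
    def create_block_mask(self, document_ids, device=None):
        """Build the attention mask (e.g. a flex-attention mask_mod) for a batch."""


class IterablePackedDataset:
    """Wraps any iterable dataset and greedily yields packed samples."""

    def __init__(self, dataset: Iterable[dict], strategy: PackingStrategy):
        self.dataset, self.strategy = dataset, strategy

    def _new_pack(self) -> dict:
        return {"tokens": [], "labels": [], "document_ids": [], "input_pos": []}

    def __iter__(self) -> Iterator[dict]:
        pack, doc_id, size = self._new_pack(), 0, 0
        for sample in self.dataset:
            # If the next sample would overflow the target, emit the current pack.
            if size > 0 and size + len(sample["tokens"]) > self.strategy.target_tokens_per_pack:
                yield pack
                pack, doc_id, size = self._new_pack(), 0, 0
            doc_id += self.strategy.add_sample_to_pack(pack, sample, doc_id)
            size += len(sample["tokens"])
        if size > 0:  # flush the final, possibly partial, pack
            yield pack
```

A concrete strategy then only has to implement the two abstract methods; the dataset wrapper and the collate stay generic.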
### Implementation:
6) Updated `full_finetune_distributed.py` to use `IterablePackedDataset` when packing is enabled. There are open challenges related to iterable datasets that will be tackled in a separate iterable-dataset PR; the changes here are only what was needed to make the recipe run for this RFC.
### Not in this PR:
7) **Logging**: Since we cannot call `len()` on an iterable dataset, we need proper logging/metadata to help users understand how far along they are in each dataset, plus metrics about the samples (avg tokens per sample, avg samples per pack, etc.).
8) **Packing-aware loss**: For SFT, the same loss works for both map-style and on-the-fly packing. This is not the case for DPO/GRPO, which would need different masking. Future work will have to associate packing with a loss that supports it.
9) **Packing-aware metrics**: Advanced metrics, such as per-sample logprob, would need to be aware of packing.
New collate function added to the collate utilities:

```python
from typing import Callable

import torch

from torchtune.modules.attention_utils import packed_block_causal_mask


def collate_packed(
    batch: list[dict[str, torch.Tensor]], mask_fn: Callable, device: str
) -> dict[str, torch.Tensor]:
    """
    Generic collate function for packed samples from an IterablePackedDataset.
    This function handles tensor stacking and delegates attention mask creation
    to a provided `mask_fn`.
    """
    if not batch:
        return {}

    # Assumes all samples in the batch have the same keys, which are all tensors.
    keys_to_stack = batch[0].keys()
    collated = {}
    for key in keys_to_stack:
        if isinstance(batch[0][key], torch.Tensor):
            collated[key] = torch.stack([sample[key] for sample in batch], dim=0)
        else:
            # TODO: Remove? I don't see a situation where it would not be a tensor.
            collated[key] = [sample[key] for sample in batch]

    # Delegate mask creation to the provided specialized function.
    # TODO: investigate the need for `device` here. Currently we hardcode it to
    # cuda in utilities. Shouldn't we just move to device later?
    collated["mask"] = mask_fn(collated["document_ids"], device=device)

    return collated
```
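Assuming the tensors-only path, the collate + mask flow can be exercised end to end. The condensed `collate_packed_demo` and the dense `block_causal_mask` below are illustrative stand-ins for the real collate and flex-attention mask, not the torchtune implementations:

```python
import torch


def collate_packed_demo(batch, mask_fn, device):
    # Condensed stand-in for the collate above: stack tensors, delegate the mask.
    if not batch:
        return {}
    collated = {k: torch.stack([s[k] for s in batch], dim=0) for k in batch[0]}
    collated["mask"] = mask_fn(collated["document_ids"], device=device)
    return collated


def block_causal_mask(document_ids, device):
    # Batched document-level causal mask: position i attends to j iff
    # both share a document id and j <= i.
    doc = document_ids.to(device)
    same_doc = doc.unsqueeze(-1) == doc.unsqueeze(-2)  # [batch, seq, seq]
    seq = doc.shape[-1]
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=device))
    return same_doc & causal


batch = [
    {"tokens": torch.tensor([1, 2, 5, 6]), "document_ids": torch.tensor([0, 0, 1, 1])},
    {"tokens": torch.tensor([7, 8, 9, 0]), "document_ids": torch.tensor([0, 1, 1, 1])},
]
out = collate_packed_demo(batch, mask_fn=block_causal_mask, device="cpu")
print(out["tokens"].shape, out["mask"].shape)  # torch.Size([2, 4]) torch.Size([2, 4, 4])
```

The real recipe would instead pass `strategy.create_block_mask`, which builds a flex-attention block mask rather than a dense boolean tensor.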