
Commit 1626ddd

Add Seq Packing in NeMo / Neva2 (NVIDIA-NeMo#11633)
* api updates and fixes
* Apply isort and black reformatting
* fix
* fix arg
* update seq packing in mock ds
* save
* update preprocess_data
* update seq packing
* Apply isort and black reformatting
* fix sp
* save
* fix seq packing
* add truncation and padding
* Apply isort and black reformatting
* Fix issues
* change LLaVATemplateConfig variables to class variables
* change to use field with default attributes
* Apply isort and black reformatting
* Apply isort and black reformatting
* Add seq packing option in energon
* Fix energon conversation
* add energon option in neva training script
* Apply isort and black reformatting
* add ci test for packed seq
* fix mock dataset seq packing
* Apply isort and black reformatting
* fix mock dataset seq packing
* Apply isort and black reformatting
* fix lint and update seq pack func
* fix energon module
* Apply isort and black reformatting
* fix comments
* Apply isort and black reformatting
* address lightning issues
* Apply isort and black reformatting
* Update sequence_packing.py
* update energon requirements
* Fix for energon update
* fix for test

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Signed-off-by: yashaswikarnati <yashaswikarnati@users.noreply.github.com>
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Co-authored-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Co-authored-by: ykarnati <ykarnati@nvidia.com>
Co-authored-by: yashaswikarnati <yashaswikarnati@users.noreply.github.com>
1 parent 3591cf8 commit 1626ddd

File tree

24 files changed (+611, -138 lines)


.github/workflows/cicd-main.yml

Lines changed: 16 additions & 2 deletions

@@ -4329,11 +4329,24 @@ jobs:
     with:
       RUNNER: self-hosted-azure
       SCRIPT: |
-        python tests/collections/vlm/neva_train.py \
+        python tests/collections/vlm/test_neva_train.py \
         --devices=1 \
         --max-steps=5 \
         --experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }}
 
+  L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        python tests/collections/vlm/test_neva_train.py \
+        --devices=1 \
+        --max-steps=5 \
+        --experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }} \
+        --use_packed_sequence
+
   L2_NeMo_2_MLLAMA_MOCK_TRAINING:
     needs: [cicd-test-container-setup]
     uses: ./.github/workflows/_test_template.yml
@@ -4342,7 +4355,7 @@ jobs:
       RUNNER: self-hosted-azure
       SCRIPT: |
         TRANSFORMERS_OFFLINE=1 \
-        python tests/collections/vlm/mllama_train.py \
+        python tests/collections/vlm/test_mllama_train.py \
         --devices=1 \
         --max-steps=5 \
         --experiment-dir=/tmp/nemo2_mllama_results/${{ github.run_id }}
@@ -5060,6 +5073,7 @@ jobs:
       - Speech_Checkpoints_tests
       - L2_Stable_Diffusion_Training
       - L2_NeMo_2_NEVA_MOCK_TRAINING
+      - L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING
       - L2_NeMo_2_MLLAMA_MOCK_TRAINING
       - L2_NeMo_2_GPT_Pretraining_no_transformer_engine
       - L2_NeMo_2_GPT_DDP_Param_Parity_check

nemo/collections/llm/peft/api.py

Lines changed: 2 additions & 2 deletions

@@ -16,10 +16,10 @@
 from pathlib import Path
 from typing import Tuple, Union
 
-import pytorch_lightning as pl
+import lightning.pytorch as pl
 import torch
+from lightning.pytorch.trainer.states import TrainerFn
 from megatron.core import dist_checkpointing
-from pytorch_lightning.trainer.states import TrainerFn
 from rich.console import Console
 
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

nemo/collections/multimodal/data/energon/base.py

Lines changed: 5 additions & 0 deletions

@@ -68,6 +68,7 @@ def __init__(
         multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(),
         task_encoder: Optional[MultiModalTaskEncoder] = None,
         decoder_seq_length: Optional[int] = None,
+        packing_buffer_size: Optional[int] = None,
     ) -> None:
         """
         Initialize the EnergonMultiModalDataModule.
@@ -84,6 +85,8 @@ def __init__(
             Defaults to MultiModalSampleConfig().
         task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
             If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
+        decoder_seq_length (int, optional): The maximum sequence length for the decoder. Used in encoder-decoder models.
+        packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None.
         """
 
         super().__init__()
@@ -113,6 +116,7 @@ def __init__(
         )
         self.train_dataloader_object = None
         self.val_dataloader_object = None
+        self.packing_buffer_size = packing_buffer_size
 
     def io_init(self, **kwargs) -> fdl.Config[Self]:
 
@@ -146,6 +150,7 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val
             task_encoder=self.task_encoder,
             worker_config=worker_config,
             max_samples_per_sequence=None,
+            packing_buffer_size=self.packing_buffer_size,
             shuffle_buffer_size=100,
             split_part=split,
         )
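Setting `packing_buffer_size` makes the Energon loader buffer that many encoded samples, hand the buffer to the task encoder's `select_samples_to_pack` hook, and then pack each returned bin with `pack_selected_samples`. A minimal, dependency-free sketch of that control flow (all names and the binning policy here are simplified stand-ins, not the real Energon internals):

```python
def select_samples_to_pack(samples, packed_sequence_size):
    """Greedily group sample lengths (longest first) into bins whose
    total stays within packed_sequence_size (simplified stand-in)."""
    bins, current, current_len = [], [], 0
    for length in sorted(samples, reverse=True):
        if current and current_len + length > packed_sequence_size:
            bins.append(current)
            current, current_len = [], 0
        current.append(length)
        current_len += length
    if current:
        bins.append(current)
    return bins


def drain_packing_buffer(stream, packing_buffer_size, packed_sequence_size):
    """Accumulate packing_buffer_size samples, then emit packed bins --
    a rough analogue of what the loader does when packing is enabled."""
    buffer = []
    for sample in stream:
        buffer.append(sample)
        if len(buffer) == packing_buffer_size:
            yield from select_samples_to_pack(buffer, packed_sequence_size)
            buffer = []
    if buffer:  # flush the partially filled buffer at end of stream
        yield from select_samples_to_pack(buffer, packed_sequence_size)


# Six samples (represented by their token lengths) buffered together,
# packed into bins of at most 1024 tokens each.
bins = list(drain_packing_buffer([300, 500, 200, 700, 100, 400], 6, 1024))
```

Larger buffers give the packer more candidates per bin and hence less padding waste, at the cost of memory.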

nemo/collections/multimodal/data/energon/config.py

Lines changed: 22 additions & 2 deletions

@@ -13,8 +13,11 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Tuple, Union
 
 import torch
+from megatron.core.packed_seq_params import PackedSeqParams
 
 from nemo.collections.multimodal.data.energon.conversation import LLaVATemplateConfig
 
 
@@ -34,7 +37,7 @@ class ImageToken(MultiModalToken):
 
 @dataclass
 class ImageTextSample:
-    '''Sample type for template formatted raw image text sample'''
+    """Sample type for template formatted raw image text sample"""
 
     __key__: str = ''
     images: torch.Tensor = field(default_factory=lambda: torch.empty(0))
@@ -43,6 +46,15 @@ class ImageTextSample:
     loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
 
 
+@dataclass
+class PackedImageTextSample(ImageTextSample):
+    """Sample type for packed image text sample"""
+
+    __restore_key__: Tuple[Union[str, int, tuple], ...] = ()
+    position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
+    packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams())
+
+
 @dataclass
 class ImageTextRawBatch:
     """Sample type for image text raw batch"""
@@ -56,6 +68,14 @@ class ImageTextRawBatch:
     loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
 
 
+@dataclass
+class PackedImageTextRawBatch(ImageTextRawBatch):
+    """Sample type for image text raw batch"""
+
+    position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
+    packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams())
+
+
 @dataclass
 class MultiModalSampleConfig:
     image_token: ImageToken = field(default_factory=ImageToken)
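The new `packed_seq_params` field carries a `PackedSeqParams` from Megatron-Core, whose cumulative-sequence-length (`cu_seqlens`) tensors tell the attention kernels where one packed sample ends and the next begins, so tokens never attend across sample boundaries. A torch-free sketch of the usual convention for deriving them (an illustration, not the exact NeMo code):

```python
from itertools import accumulate


def build_cu_seqlens(seq_lengths):
    """Cumulative boundaries of samples packed into one row:
    [0, len0, len0+len1, ...]. Attention for token t is restricted
    to the interval [cu_seqlens[i], cu_seqlens[i+1]) containing t."""
    return [0] + list(accumulate(seq_lengths))


# Three samples of lengths 5, 3, and 4 packed into one 12-token row.
cu_seqlens = build_cu_seqlens([5, 3, 4])
max_seqlen = max([5, 3, 4])  # longest single sample in the packed row
```

`PackedSeqParams` typically holds such values as `cu_seqlens_q` / `cu_seqlens_kv` tensors plus the corresponding `max_seqlen` entries for the fused attention path.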

nemo/collections/multimodal/data/energon/conversation.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ class LLaVATemplateConfig(BaseConversationTemplateConfig):
     """LLava-specific template configuration which extends the base config"""
 
     system: str = field(
-        default="A chat between a curious user and artificial assistant agent. "
+        default="A chat between a curious user and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed and polite answers to user's questions."
     )
     roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])

nemo/collections/multimodal/data/energon/task_encoder.py

Lines changed: 143 additions & 27 deletions

@@ -25,14 +25,21 @@
     batch_list,
     batch_pad_stack,
 )
+from megatron.energon.task_encoder.base import stateless
 
-from nemo.collections.multimodal.data.energon.config import ImageTextRawBatch, ImageTextSample
+from nemo.collections.multimodal.data.energon.config import (
+    ImageTextRawBatch,
+    ImageTextSample,
+    PackedImageTextRawBatch,
+    PackedImageTextSample,
+)
 from nemo.collections.multimodal.data.energon.sample_encoder import (
     InterleavedSampleEncoder,
     SampleEncoder,
     SimilarityInterleavedEncoder,
     VQASampleEncoder,
 )
+from nemo.utils import logging
 
 
 class MultiModalTaskEncoder(
@@ -54,16 +61,34 @@ class MultiModalTaskEncoder(
     for model input.
     """
 
-    def __init__(self, tokenizer, image_processor, multimodal_sample_config):
+    def __init__(
+        self,
+        tokenizer,
+        image_processor,
+        multimodal_sample_config,
+        packed_sequence=False,
+        packed_sequence_size=-1,
+        num_image_embeddings_per_tile=576,
+    ):
         """
         Initialize the MultiModalTaskEncoder with specific encoders for different sample types.
 
         Parameters:
-        tokenizer (Tokenizer): The tokenizer used for processing text across different sample types.
-        image_processor (ImageProcessor): The image processor used for preprocessing images.
-        multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object.
+        tokenizer (Tokenizer): The tokenizer used for processing textual components across sample types.
+        image_processor (ImageProcessor): The image processor responsible for preprocessing image data.
+        multimodal_sample_config (MultiModalSampleConfig): Configuration object defining properties and
+            requirements for multimodal samples.
+        packed_sequence (bool, optional): Flag indicating whether packed sequences are used. Default is False.
+        packed_sequence_size (int, optional): The size of packed sequences, used when `packed_sequence` is True.
+            Default is -1.
+        num_image_embeddings_per_tile (int, optional): Number of image embeddings per image tile. Determines
+            the granularity of image features. Default is 576.
         """
         self.tokenizer = tokenizer
+        self.sample_config = multimodal_sample_config
+        self.packed_sequence = packed_sequence
+        self.num_image_embeddings_per_tile = num_image_embeddings_per_tile  # only used with seq packing
+        self.packed_sequence_size = packed_sequence_size
         self.encoders: Dict[str, SampleEncoder] = {
             VQASample.__name__: VQASampleEncoder(
                 tokenizer=tokenizer,
@@ -92,6 +117,7 @@ def register_encoder(self, sample_type: str, encoder: SampleEncoder) -> None:
         """
         self.encoders[sample_type] = encoder
 
+    @stateless
     def encode_sample(
         self, sample: Union[VQASample, InterleavedSample, SimilarityInterleavedSample, CaptioningSample]
     ) -> ImageTextSample:
@@ -118,7 +144,9 @@ def encode_sample(
         encoded_sample = encoder.encode(input_sample=sample, output_sample=ImageTextSample())
         return encoded_sample
 
-    def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch:
+    def batch(
+        self, samples: List[Union[ImageTextSample, PackedImageTextSample]]
+    ) -> Union[ImageTextRawBatch, PackedImageTextRawBatch]:
         """
         Batch a list of encoded samples into a single raw batch.
 
@@ -131,26 +159,51 @@ def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch:
             ImageTextRawBatch: The batched data, including images, tokens, labels, and loss masks.
         """
 
-        keys, images, tokens, labels, loss_mask = [], [], [], [], []
-        for sample in samples:
-            keys.append(sample.__key__)
-            images.append(sample.images)
-            tokens.append(sample.tokens)
-            labels.append(sample.labels)
-            loss_mask.append(sample.loss_mask)
-
-        batch_keys = batch_list(keys)
-        batch_images = batch_pad_stack(images)
-        batch_prompt_tokens = batch_pad_stack(tokens)
-        batch_labels = batch_pad_stack(labels)
-        batch_loss_mask = batch_pad_stack(loss_mask)
-        return ImageTextRawBatch(
-            __keys__=batch_keys,
-            images=batch_images,
-            tokens=batch_prompt_tokens,
-            labels=batch_labels,
-            loss_mask=batch_loss_mask,
-        )
+        if self.packed_sequence:
+            if len(samples) > 1:
+                raise ValueError(
+                    "Micro batch size should be 1 when training with packed sequence, but your micro batch size "
+                    f"is {len(samples)}. \nThe following config is equivalent to your current setting for "
+                    f"a packed dataset. Please update your config to the following: \n"
+                    f"Set micro batch size to 1 (currently {len(samples)})\n"
+                    f"Set global batch size to `global_batch_size // {len(samples)}` "
+                    f"Set packed sequence length to `original_sample_seq_len * {len(samples)}` "
+                    f"(currently {self.packed_sequence_size}) \n"
+                    f"For details please visit "
+                    f"https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/packed_sequence.html"
+                )
+            # The batching are taken care by packing.
+            sample = samples[0]
+            return PackedImageTextRawBatch(
+                __keys__=sample.__key__,
+                images=sample.images,
+                tokens=sample.tokens,
+                labels=sample.labels,
+                loss_mask=sample.loss_mask,
+                position_ids=sample.position_ids,
+                packed_seq_params=sample.packed_seq_params,
+            )
+        else:
+            keys, images, tokens, labels, loss_mask = [], [], [], [], []
+            for sample in samples:
+                keys.append(sample.__key__)
+                images.append(sample.images)
+                tokens.append(sample.tokens)
+                labels.append(sample.labels)
+                loss_mask.append(sample.loss_mask)
+
+            batch_keys = batch_list(keys)
+            batch_images = batch_pad_stack(images)
+            batch_prompt_tokens = batch_pad_stack(tokens)
+            batch_labels = batch_pad_stack(labels)
+            batch_loss_mask = batch_pad_stack(loss_mask)
+            return ImageTextRawBatch(
+                __keys__=batch_keys,
+                images=batch_images,
+                tokens=batch_prompt_tokens,
+                labels=batch_labels,
+                loss_mask=batch_loss_mask,
+            )
 
     def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
         """
@@ -165,7 +218,7 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
         Returns:
             dict: A dictionary containing the encoded batch data, ready for model input.
         """
-        batch_dict = dataclasses.asdict(batch_data)
+        batch_dict = batch_data.__dict__
         if 'images' in batch_dict:
             batch_dict['media'] = batch_dict['images']
             del batch_dict['images']
@@ -177,3 +230,66 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
         if 'attention_mask' not in batch_dict:
             batch_dict['attention_mask'] = None
         return batch_dict
+
+    def select_samples_to_pack(self, samples):
+        """Selects which samples will be packed together.
+
+        NOTE: Energon dataloader calls this method internally if packing is used.
+        Please see https://nvidia.github.io/Megatron-Energon/packing.html
+        """
+        from nemo.collections.vlm.neva.data.sequence_packing import greedy_knapsack, predict_seq_len
+
+        media_token_id = self.sample_config.image_token.token_id
+        lengths = [
+            predict_seq_len(
+                sample.tokens,
+                media_token_index=media_token_id,
+                num_image_embeddings_per_tile=self.num_image_embeddings_per_tile,
+            )
+            for sample in samples
+        ]
+        packed_samples = greedy_knapsack(lengths, samples, self.packed_sequence_size)
+        avg_samples_per_bin = round(len(lengths) / len(packed_samples))
+        logging.info(
+            f"[Seq Packing Info] - Packing seq len: {self.packed_sequence_size}, "
+            f"Buffered samples: {len(lengths)}, Total number of bins: {len(packed_samples)}, "
+            f"Average samples per bin: {avg_samples_per_bin}"
+        )
+        return packed_samples
+
+    @stateless
+    def pack_selected_samples(self, samples):
+        """
+        Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked.
+
+        NOTE: Energon dataloader calls this method internally if packing is used.
+        Please see https://nvidia.github.io/Megatron-Energon/packing.html
+
+        Args:
+            samples: List of ImageTaskSample instances to pack into one sample.
+
+        Returns:
+            ImageTaskSamplePacked instance.
+        """
+        from nemo.collections.vlm.neva.data.sequence_packing import convert_to_packed
+
+        packed_images = torch.stack([sample.images for sample in samples])
+        media_token_id = self.sample_config.image_token.token_id
+        packed_tokens, packed_labels, packed_position_ids, packed_loss_mask, packed_seq_params = convert_to_packed(
+            tokens=[sample.tokens for sample in samples],
+            labels=[sample.labels for sample in samples],
+            num_image_embeddings_per_tile=self.num_image_embeddings_per_tile,
+            media_token_index=media_token_id,
+            ignore_index=self.sample_config.ignore_place_holder,
+        )
+
+        return PackedImageTextSample(
+            __key__=",".join([s.__key__ for s in samples]),
+            __restore_key__=(),  # Will be set by energon based on `samples`
+            tokens=packed_tokens,
+            labels=packed_labels,
+            images=packed_images,
+            position_ids=packed_position_ids,
+            loss_mask=packed_loss_mask,
+            packed_seq_params=packed_seq_params,
+        )
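The packing hooks added above lean on `greedy_knapsack` and `predict_seq_len` from `nemo.collections.vlm.neva.data.sequence_packing`. A hedged sketch of what such helpers typically do (simplified stand-ins with the same call signatures as in the diff; the shipped implementations may differ):

```python
def predict_seq_len(tokens, media_token_index, num_image_embeddings_per_tile):
    """Estimate the post-expansion length of a sample: each media token
    is replaced by num_image_embeddings_per_tile image embeddings, so the
    sequence grows by (num_image_embeddings_per_tile - 1) per media token."""
    num_media = sum(1 for t in tokens if t == media_token_index)
    return len(tokens) + num_media * (num_image_embeddings_per_tile - 1)


def greedy_knapsack(lengths, samples, packed_sequence_size):
    """First-fit-decreasing binning: visit samples longest first, place
    each into the first bin with room, else open a new bin."""
    order = sorted(range(len(samples)), key=lambda i: lengths[i], reverse=True)
    bins, used = [], []
    for i in order:
        for b in range(len(bins)):
            if used[b] + lengths[i] <= packed_sequence_size:
                bins[b].append(samples[i])
                used[b] += lengths[i]
                break
        else:
            bins.append([samples[i]])
            used.append(lengths[i])
    return bins


# A 4-token sample with one image placeholder (-200) expands far beyond
# its raw token count once image embeddings are accounted for.
expanded = predict_seq_len([1, 2, -200, 3], -200, 576)
# Pack four samples (identified here by their lengths) into 1024-token bins.
packed = greedy_knapsack([600, 500, 400, 300], [600, 500, 400, 300], 1024)
```

This is why `select_samples_to_pack` predicts expanded lengths first: binning by raw token counts would badly overfill bins once each image token fans out into hundreds of embeddings.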

nemo/collections/vlm/inference/base.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 
 from typing import List, Optional, Union
 
-import pytorch_lightning as pl
+import lightning.pytorch as pl
 import torch
 import torch.distributed
 from megatron.core.inference.common_inference_params import CommonInferenceParams
