2525 batch_list ,
2626 batch_pad_stack ,
2727)
28+ from megatron .energon .task_encoder .base import stateless
2829
29- from nemo .collections .multimodal .data .energon .config import ImageTextRawBatch , ImageTextSample
30+ from nemo .collections .multimodal .data .energon .config import (
31+ ImageTextRawBatch ,
32+ ImageTextSample ,
33+ PackedImageTextRawBatch ,
34+ PackedImageTextSample ,
35+ )
3036from nemo .collections .multimodal .data .energon .sample_encoder import (
3137 InterleavedSampleEncoder ,
3238 SampleEncoder ,
3339 SimilarityInterleavedEncoder ,
3440 VQASampleEncoder ,
3541)
42+ from nemo .utils import logging
3643
3744
3845class MultiModalTaskEncoder (
@@ -54,16 +61,34 @@ class MultiModalTaskEncoder(
5461 for model input.
5562 """
5663
57- def __init__ (self , tokenizer , image_processor , multimodal_sample_config ):
64+ def __init__ (
65+ self ,
66+ tokenizer ,
67+ image_processor ,
68+ multimodal_sample_config ,
69+ packed_sequence = False ,
70+ packed_sequence_size = - 1 ,
71+ num_image_embeddings_per_tile = 576 ,
72+ ):
5873 """
5974 Initialize the MultiModalTaskEncoder with specific encoders for different sample types.
6075
6176 Parameters:
62- tokenizer (Tokenizer): The tokenizer used for processing text across different sample types.
63- image_processor (ImageProcessor): The image processor used for preprocessing images.
64- multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object.
77+ tokenizer (Tokenizer): The tokenizer used for processing textual components across sample types.
78+ image_processor (ImageProcessor): The image processor responsible for preprocessing image data.
79+ multimodal_sample_config (MultiModalSampleConfig): Configuration object defining properties and
80+ requirements for multimodal samples.
81+ packed_sequence (bool, optional): Flag indicating whether packed sequences are used. Default is False.
82+ packed_sequence_size (int, optional): The size of packed sequences, used when `packed_sequence` is True.
83+ Default is -1.
84+ num_image_embeddings_per_tile (int, optional): Number of image embeddings per image tile. Determines
85+ the granularity of image features. Default is 576.
6586 """
6687 self .tokenizer = tokenizer
88+ self .sample_config = multimodal_sample_config
89+ self .packed_sequence = packed_sequence
90+ self .num_image_embeddings_per_tile = num_image_embeddings_per_tile # only used with seq packing
91+ self .packed_sequence_size = packed_sequence_size
6792 self .encoders : Dict [str , SampleEncoder ] = {
6893 VQASample .__name__ : VQASampleEncoder (
6994 tokenizer = tokenizer ,
@@ -92,6 +117,7 @@ def register_encoder(self, sample_type: str, encoder: SampleEncoder) -> None:
92117 """
93118 self .encoders [sample_type ] = encoder
94119
120+ @stateless
95121 def encode_sample (
96122 self , sample : Union [VQASample , InterleavedSample , SimilarityInterleavedSample , CaptioningSample ]
97123 ) -> ImageTextSample :
@@ -118,7 +144,9 @@ def encode_sample(
118144 encoded_sample = encoder .encode (input_sample = sample , output_sample = ImageTextSample ())
119145 return encoded_sample
120146
121- def batch (self , samples : List [ImageTextSample ]) -> ImageTextRawBatch :
147+ def batch (
148+ self , samples : List [Union [ImageTextSample , PackedImageTextSample ]]
149+ ) -> Union [ImageTextRawBatch , PackedImageTextRawBatch ]:
122150 """
123151 Batch a list of encoded samples into a single raw batch.
124152
@@ -131,26 +159,51 @@ def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch:
131159 ImageTextRawBatch: The batched data, including images, tokens, labels, and loss masks.
132160 """
133161
134- keys , images , tokens , labels , loss_mask = [], [], [], [], []
135- for sample in samples :
136- keys .append (sample .__key__ )
137- images .append (sample .images )
138- tokens .append (sample .tokens )
139- labels .append (sample .labels )
140- loss_mask .append (sample .loss_mask )
141-
142- batch_keys = batch_list (keys )
143- batch_images = batch_pad_stack (images )
144- batch_prompt_tokens = batch_pad_stack (tokens )
145- batch_labels = batch_pad_stack (labels )
146- batch_loss_mask = batch_pad_stack (loss_mask )
147- return ImageTextRawBatch (
148- __keys__ = batch_keys ,
149- images = batch_images ,
150- tokens = batch_prompt_tokens ,
151- labels = batch_labels ,
152- loss_mask = batch_loss_mask ,
153- )
162+ if self .packed_sequence :
163+ if len (samples ) > 1 :
164+ raise ValueError (
165+ "Micro batch size should be 1 when training with packed sequence, but your micro batch size "
166+ f"is { len (samples )} . \n The following config is equivalent to your current setting for "
167+ f"a packed dataset. Please update your config to the following: \n "
168+ f"Set micro batch size to 1 (currently { len (samples )} )\n "
169+ f"Set global batch size to `global_batch_size // { len (samples )} ` "
170+ f"Set packed sequence length to `original_sample_seq_len * { len (samples )} ` "
171+ f"(currently { self .packed_sequence_size } ) \n "
172+ f"For details please visit "
173+ f"https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/packed_sequence.html"
174+ )
175+ # The batching are taken care by packing.
176+ sample = samples [0 ]
177+ return PackedImageTextRawBatch (
178+ __keys__ = sample .__key__ ,
179+ images = sample .images ,
180+ tokens = sample .tokens ,
181+ labels = sample .labels ,
182+ loss_mask = sample .loss_mask ,
183+ position_ids = sample .position_ids ,
184+ packed_seq_params = sample .packed_seq_params ,
185+ )
186+ else :
187+ keys , images , tokens , labels , loss_mask = [], [], [], [], []
188+ for sample in samples :
189+ keys .append (sample .__key__ )
190+ images .append (sample .images )
191+ tokens .append (sample .tokens )
192+ labels .append (sample .labels )
193+ loss_mask .append (sample .loss_mask )
194+
195+ batch_keys = batch_list (keys )
196+ batch_images = batch_pad_stack (images )
197+ batch_prompt_tokens = batch_pad_stack (tokens )
198+ batch_labels = batch_pad_stack (labels )
199+ batch_loss_mask = batch_pad_stack (loss_mask )
200+ return ImageTextRawBatch (
201+ __keys__ = batch_keys ,
202+ images = batch_images ,
203+ tokens = batch_prompt_tokens ,
204+ labels = batch_labels ,
205+ loss_mask = batch_loss_mask ,
206+ )
154207
155208 def encode_batch (self , batch_data : ImageTextRawBatch ) -> dict :
156209 """
@@ -165,7 +218,7 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
165218 Returns:
166219 dict: A dictionary containing the encoded batch data, ready for model input.
167220 """
168- batch_dict = dataclasses . asdict ( batch_data )
221+ batch_dict = batch_data . __dict__
169222 if 'images' in batch_dict :
170223 batch_dict ['media' ] = batch_dict ['images' ]
171224 del batch_dict ['images' ]
@@ -177,3 +230,66 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
177230 if 'attention_mask' not in batch_dict :
178231 batch_dict ['attention_mask' ] = None
179232 return batch_dict
233+
234+ def select_samples_to_pack (self , samples ):
235+ """Selects which samples will be packed together.
236+
237+ NOTE: Energon dataloader calls this method internally if packing is used.
238+ Please see https://nvidia.github.io/Megatron-Energon/packing.html
239+ """
240+ from nemo .collections .vlm .neva .data .sequence_packing import greedy_knapsack , predict_seq_len
241+
242+ media_token_id = self .sample_config .image_token .token_id
243+ lengths = [
244+ predict_seq_len (
245+ sample .tokens ,
246+ media_token_index = media_token_id ,
247+ num_image_embeddings_per_tile = self .num_image_embeddings_per_tile ,
248+ )
249+ for sample in samples
250+ ]
251+ packed_samples = greedy_knapsack (lengths , samples , self .packed_sequence_size )
252+ avg_samples_per_bin = round (len (lengths ) / len (packed_samples ))
253+ logging .info (
254+ f"[Seq Packing Info] - Packing seq len: { self .packed_sequence_size } , "
255+ f"Buffered samples: { len (lengths )} , Total number of bins: { len (packed_samples )} , "
256+ f"Average samples per bin: { avg_samples_per_bin } "
257+ )
258+ return packed_samples
259+
260+ @stateless
261+ def pack_selected_samples (self , samples ):
262+ """
263+ Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked.
264+
265+ NOTE: Energon dataloader calls this method internally if packing is used.
266+ Please see https://nvidia.github.io/Megatron-Energon/packing.html
267+
268+ Args:
269+ samples: List of ImageTaskSample instances to pack into one sample.
270+
271+ Returns:
272+ ImageTaskSamplePacked instance.
273+ """
274+ from nemo .collections .vlm .neva .data .sequence_packing import convert_to_packed
275+
276+ packed_images = torch .stack ([sample .images for sample in samples ])
277+ media_token_id = self .sample_config .image_token .token_id
278+ packed_tokens , packed_labels , packed_position_ids , packed_loss_mask , packed_seq_params = convert_to_packed (
279+ tokens = [sample .tokens for sample in samples ],
280+ labels = [sample .labels for sample in samples ],
281+ num_image_embeddings_per_tile = self .num_image_embeddings_per_tile ,
282+ media_token_index = media_token_id ,
283+ ignore_index = self .sample_config .ignore_place_holder ,
284+ )
285+
286+ return PackedImageTextSample (
287+ __key__ = "," .join ([s .__key__ for s in samples ]),
288+ __restore_key__ = (), # Will be set by energon based on `samples`
289+ tokens = packed_tokens ,
290+ labels = packed_labels ,
291+ images = packed_images ,
292+ position_ids = packed_position_ids ,
293+ loss_mask = packed_loss_mask ,
294+ packed_seq_params = packed_seq_params ,
295+ )
0 commit comments