@@ -156,8 +156,11 @@ def __call__(self, features, return_tensors=None):
         sequence processing capabilities. When pad_to_multiple_of is used, an additional
         mock sequence is appended to reach the desired total length.
         """
+        if return_tensors is not None and return_tensors != "pt":
+            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")
+
         # Perform the masking with the BSHD collator.
-        bshd_batch = self.collator(features)
+        bshd_batch = self.collator(features, return_tensors=return_tensors)
 
         # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values.
         packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)
@@ -279,33 +282,48 @@ def __iter__(self):
         samples = []
         current_length = 0
         for sample in iter(self.dataset):
-            current_length += self._padded_len(len(sample["input_ids"]))
+            sample_length = len(sample["input_ids"])
+            if sample_length > self.max_tokens_per_batch:
+                raise ValueError(
+                    f"TokenPackingDataset: Sample length ({sample_length}) exceeds max_tokens_per_batch "
+                    f"({self.max_tokens_per_batch}). Set truncation or a maximum length in your tokenizer or dataset to "
+                    "ensure all samples fit within max_tokens_per_batch."
+                )
+
+            current_length += self._padded_len(sample_length)
             if current_length == self.max_tokens_per_batch:
                 yield [*samples, sample]
                 samples = []
                 current_length = 0
 
             elif current_length > self.max_tokens_per_batch:
                 if not self.split_samples:
-                    # If we are not splitting samples, we can just yield the current batch (before this sample) and
-                    # start a new one.
-                    yield samples
+                    # Yield the current batch (before this sample) and start a new one with this sample.
+                    if samples:
+                        yield samples
                     samples = [sample]
-
+                    current_length = self._padded_len(sample_length)
                 else:
-                    # Calculate how many padded tokens are already in the batch
-                    tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"]))
+                    # Calculate how many padded tokens are already in the batch.
+                    tokens_in_batch = current_length - self._padded_len(sample_length)
                     # Calculate how many tokens we can fit from this sample, ensuring the
                     # padded length doesn't exceed the remaining capacity.
                     tokens_available = self.max_tokens_per_batch - tokens_in_batch
                     if self.pad_sequences_to_be_divisible_by is not None:
                         d = self.pad_sequences_to_be_divisible_by
                         tokens_available = (tokens_available // d) * d
-                    first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-                    yield [*samples, first_part]
-                    samples = [remaining_part]
-
-                    current_length = self._padded_len(len(samples[0]["input_ids"]))
+                    if tokens_available <= 0:
+                        # Remaining capacity is less than pad_sequences_to_be_divisible_by;
+                        # we can't fit any tokens from this sample. Yield the current batch and start fresh.
+                        if samples:
+                            yield samples
+                        samples = [sample]
+                        current_length = self._padded_len(sample_length)
+                    else:
+                        first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
+                        yield [*samples, first_part]
+                        samples = [remaining_part]
+                        current_length = self._padded_len(len(samples[0]["input_ids"]))
             else:
                 samples.append(sample)
 
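To make the packing behavior concrete, here is a self-contained toy that mirrors the patched loop on plain lists. `padded_len`, `split_by_num_tokens`, and `pack` are hypothetical stand-ins for `_padded_len`, `_split_sample_by_num_tokens`, and `TokenPackingDataset.__iter__`; the final flush of a partial batch is an assumption, since the diff does not show how the loop tail is handled:

```python
def padded_len(n, multiple=None):
    # Round n up to a multiple of `multiple` (mirrors pad_sequences_to_be_divisible_by).
    return n if multiple is None else -(-n // multiple) * multiple


def split_by_num_tokens(sample, num_tokens):
    # Stand-in for _split_sample_by_num_tokens, operating on plain lists.
    return sample[:num_tokens], sample[num_tokens:]


def pack(dataset, max_tokens, multiple=None, split_samples=True):
    samples, current_length = [], 0
    for sample in dataset:
        sample_length = len(sample)
        if sample_length > max_tokens:
            raise ValueError(f"Sample length ({sample_length}) exceeds max_tokens ({max_tokens}).")
        current_length += padded_len(sample_length, multiple)
        if current_length == max_tokens:
            yield [*samples, sample]
            samples, current_length = [], 0
        elif current_length > max_tokens:
            if not split_samples:
                if samples:
                    yield samples
                samples, current_length = [sample], padded_len(sample_length, multiple)
            else:
                tokens_available = max_tokens - (current_length - padded_len(sample_length, multiple))
                if multiple is not None:
                    tokens_available = (tokens_available // multiple) * multiple
                if tokens_available <= 0:
                    if samples:
                        yield samples
                    samples, current_length = [sample], padded_len(sample_length, multiple)
                else:
                    first_part, remaining_part = split_by_num_tokens(sample, tokens_available)
                    yield [*samples, first_part]
                    samples = [remaining_part]
                    current_length = padded_len(len(remaining_part), multiple)
        else:
            samples.append(sample)
    if samples:  # assumption: flush the final partial batch
        yield samples


# The 5-token sample overflows the 8-token budget and is split at the boundary.
print(list(pack([[1] * 6, [2] * 5, [3] * 4], max_tokens=8)))
# [[[1, 1, 1, 1, 1, 1], [2, 2]], [[2, 2, 2], [3, 3, 3, 3]]]
```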