@@ -968,6 +968,100 @@ def __iter__(self):
     )
 
 
+def test_token_packing_dataset_padding_split_remaining_capacity_below_divisor():
+    """Test that split mode handles remaining capacity below pad_sequences_to_be_divisible_by.
+
+    When the remaining batch capacity (after rounding down to the pad divisor) is 0,
+    the current batch must be yielded and the sample starts a new batch. Without this
+    guard, _split_sample_by_num_tokens would be called with tokens_available=0 and crash.
+
+    max=12, pad=8, split=True:
+    - s1: raw=5, padded=8. current=8 < 12. Append.
+    - s2: raw=3, padded=8. current=8+8=16 > 12.
+      tokens_in_batch=8, tokens_available=12-8=4, rounded to (4//8)*8=0 → yield [s1], fresh batch.
+    - s3: raw=4, padded=8. current=8+8=16 > 12. Same: yield [s2], fresh batch.
+    """
+
+    class MockDataset(torch.utils.data.IterableDataset):
+        def __iter__(self):
+            yield {"input_ids": list(range(5))}  # padded to 8
+            yield {"input_ids": list(range(3))}  # padded to 8
+            yield {"input_ids": list(range(4))}  # padded to 8
+
+    dataset = MockDataset()
+    token_packing_dataset = TokenPackingDataset(
+        dataset,
+        max_tokens_per_batch=12,
+        pad_sequences_to_be_divisible_by=8,
+        split_samples=True,
+        drop_last=False,
+    )
+    batches = list(token_packing_dataset)
+
+    # Each sample pads to 8; only one fits per batch (8 < 12, but 8+8=16 > 12,
+    # and remaining capacity 4 rounds down to 0 with pad=8).
+    assert len(batches) == 3
+    assert [len(s["input_ids"]) for s in batches[0]] == [5]
+    assert [len(s["input_ids"]) for s in batches[1]] == [3]
+    assert [len(s["input_ids"]) for s in batches[2]] == [4]
+
+
+def test_token_packing_dataset_padding_no_split_yields_before_overflow():
+    """Test that non-split mode correctly yields the batch before a padded sample overflows.
+
+    max=12, pad=8, split=False:
+    - s1: raw=5, padded=8. current=8 < 12. Append.
+    - s2: raw=3, padded=8. current=8+8=16 > 12. Yield [s1], start fresh with s2.
+    - s3: raw=4, padded=8. current=8+8=16 > 12. Yield [s2], start fresh with s3.
+    """
+
+    class MockDataset(torch.utils.data.IterableDataset):
+        def __iter__(self):
+            yield {"input_ids": list(range(5))}  # padded to 8
+            yield {"input_ids": list(range(3))}  # padded to 8
+            yield {"input_ids": list(range(4))}  # padded to 8
+
+    dataset = MockDataset()
+    token_packing_dataset = TokenPackingDataset(
+        dataset,
+        max_tokens_per_batch=12,
+        pad_sequences_to_be_divisible_by=8,
+        split_samples=False,
+        drop_last=False,
+    )
+    batches = list(token_packing_dataset)
+
+    # Each sample pads to 8; only one fits per batch (8 < 12, but 8+8=16 > 12).
+    assert len(batches) == 3
+    assert [len(s["input_ids"]) for s in batches[0]] == [5]
+    assert [len(s["input_ids"]) for s in batches[1]] == [3]
+    assert [len(s["input_ids"]) for s in batches[2]] == [4]
+
+
+def test_token_packing_dataset_oversized_sample_raises():
+    """Test that a sample exceeding max_tokens_per_batch raises a ValueError.
+
+    Users should set truncation or a maximum length in their tokenizer/dataset to ensure
+    all samples fit within max_tokens_per_batch.
+    """
+
+    class MockDataset(torch.utils.data.IterableDataset):
+        def __iter__(self):
+            yield {"input_ids": list(range(5))}  # fits
+            yield {"input_ids": list(range(25))}  # exceeds max of 10
+
+    dataset = MockDataset()
+    token_packing_dataset = TokenPackingDataset(
+        dataset,
+        max_tokens_per_batch=10,
+        split_samples=False,
+        drop_last=False,
+    )
+
+    with pytest.raises(ValueError, match="Sample length.*exceeds max_tokens_per_batch"):
+        list(token_packing_dataset)
+
+
 def test_token_packing_dataset_with_padding_split_drop_last_false(tokenizer):
     """Test that with drop_last=False, all batches except the last have exactly max_tokens."""
     pad_divisor = 4