@@ -64,9 +64,9 @@ def test_basic_padding(self):
 
         result = self.collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
 
     def test_completion_mask(self):
@@ -79,9 +79,9 @@ def test_completion_mask(self):
 
         result = self.collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))
 
     def test_completion_only_loss_disabled(self):
@@ -95,9 +95,9 @@ def test_completion_only_loss_disabled(self):
         result = collator(examples)
 
         # Labels should not be masked when completion_only_loss=False
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
 
     def test_padding_free_mode(self):
@@ -107,72 +107,42 @@ def test_padding_free_mode(self):
 
         result = collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]]))
 
     def test_padding_free_with_completion_mask(self):
         """Test padding-free mode with completion masks."""
         collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
         examples = [
-            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
+            {"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]},
             {"input_ids": [4, 5], "completion_mask": [1, 1]},
         ]
 
         result = collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]]))
 
-    def test_packing_drops_attention_mask_for_flash_attention(self):
+    def test_packing(self):
         """Test that when using packing with position_ids, attention_mask is dropped with fa2."""
-        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
+        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
 
         # Simulate packed sequences with position_ids that restart (typical of BFD packing)
         examples = [
-            {
-                "input_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # Packed: [1,2,3] + [4,5] + [6,7,8]
-                "seq_lengths": [3, 2, 3],
-            }
+            {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]},
+            {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]},
         ]
 
         result = collator(examples)
 
-        # Verify that attention_mask is NOT present - this allows FlashAttention to use position_ids
-        self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")
-
-        # Verify essential keys are present
-        self.assertIn("input_ids", result)
-        self.assertIn("position_ids", result)
-        self.assertIn("labels", result)
-
-        # Verify the data is correctly processed
-        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
-
-    def test_padding_free_without_position_ids_keeps_attention_mask(self):
-        """
-        Test that padding_free mode without explicit position_ids still creates attention_mask.
-        """
-        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
-
-        # Examples without position_ids (not packed)
-        examples = [{"input_ids": [1, 2, 3, 4, 5]}]
-
-        result = collator(examples)
-
-        # Should still have attention_mask since no packed position_ids
-        self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
-        self.assertIn("position_ids", result)
-        self.assertIn("input_ids", result)
-
-        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))
+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]))
+        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]]))
 
     def test_pad_to_multiple_of(self):
         """Test padding to multiple of specified value."""
@@ -181,9 +151,9 @@ def test_pad_to_multiple_of(self):
 
         result = collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
 
     def test_pad_to_multiple_of_and_padding_free(self):
@@ -193,21 +163,21 @@ def test_pad_to_multiple_of_and_padding_free(self):
 
         result = collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]]))
 
-    def test_custom_position_ids(self):
-        """Test handling of custom position IDs in examples."""
+    def test_custom_position_ids_but_no_padding_free(self):
+        """Test that custom position_ids are ignored if padding_free is False."""
         self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
         examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}]
 
         result = self.collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
 
     def test_single_example(self):
@@ -217,9 +187,9 @@ def test_single_example(self):
 
         result = self.collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))
 
     def test_different_pad_token_id(self):
@@ -229,9 +199,9 @@ def test_different_pad_token_id(self):
 
         result = collator(examples)
 
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
 
     def test_assistant_masks(self):
@@ -246,7 +216,6 @@ def test_assistant_masks(self):
 
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))
 
     def test_single_example_single_doc(self):
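
For reference, every test above drives the collator the same way: construct DataCollatorForLanguageModeling with a pad_token_id (plus optional flags such as padding_free or pad_to_multiple_of) and call it directly on a list of feature dicts. A minimal sketch of that pattern follows; the import path is an assumption and may differ between TRL versions, so treat it as illustrative rather than the library's documented API.

# Minimal usage sketch mirroring the tests above.
# Assumption: DataCollatorForLanguageModeling is importable from trl; adjust
# the import to wherever the class lives in your TRL version.
from trl import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
batch = collator(examples)

# Per the updated assertions, padding-free batches carry exactly these keys,
# with the first token of each packed sequence masked to -100 in "labels".
assert set(batch.keys()) == {"input_ids", "position_ids", "labels"}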