
Commit 4ee9713

Fix loss masking (#445)
Co-authored-by: Joel Lamy-Poirier <joel.lamy-poirier@servicenow.com>
1 parent 91b28fa commit 4ee9713

2 files changed: 268 additions & 6 deletions

fast_llm/models/gpt/model.py

Lines changed: 4 additions & 6 deletions
@@ -263,7 +263,6 @@ def preprocess_batch(
         if phase != PhaseType.inference:
             labels_begin = tokens_begin + 1
             labels_end = tokens_end + self._config.head.max_prediction_distance
-
             labels = batch.tokens.crop(labels_begin, labels_end).tokens

             if batch.loss_masking_spans is not None:
@@ -272,13 +271,12 @@ def preprocess_batch(
                 for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges):
                     for begin, end in loss_masking_spans:
                         loss_mask[sample_index, begin:end] = False
-                if (
-                    self._config.head.distillation_model is not None
-                    or self._config.decoder.block.distillation_model is not None
-                ):
-                    kwargs[LanguageModelKwargs.loss_mask] = loss_mask
                 labels = torch.where(loss_mask, labels, -100)

+            if self._config.head.distillation_model is not None:  # loss masks only used for distillation currently
+                # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders
+                kwargs[LanguageModelKwargs.loss_mask] = labels >= 0
+
             kwargs[LanguageModelKwargs.labels] = (
                 labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels
             ).contiguous()
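
For reference, the combined masking behaviour introduced above can be illustrated with a small standalone sketch in plain PyTorch (illustrative only, not the Fast-LLM API; the tensor values and span are made up): negative tokens mark padding and image placeholders, user-defined spans are zeroed into a boolean mask, both are folded into the labels as -100, and the distillation loss mask is then recovered with a single `labels >= 0` comparison.

# Standalone sketch of the masking logic above (plain torch, illustrative values).
import torch

labels = torch.tensor([[100, -100, -100, 104, 105, 106, 107, 108]])  # -100 = padding / image placeholder
loss_masking_spans = [[(3, 5)]]  # hypothetical user-defined spans, per sample, already in label space

loss_mask = torch.ones_like(labels, dtype=torch.bool)
for sample_index, spans in enumerate(loss_masking_spans):
    for begin, end in spans:
        loss_mask[sample_index, begin:end] = False

# Span-masked positions are folded into the labels as -100 ...
labels = torch.where(loss_mask, labels, -100)
# ... so a single comparison recovers the combined mask
# (padding, user-defined spans, image placeholders):
combined_loss_mask = labels >= 0

print(labels)              # tensor([[ 100, -100, -100, -100, -100,  106,  107,  108]])
print(combined_loss_mask)  # tensor([[ True, False, False, False, False,  True,  True,  True]])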

tests/test_loss_mask.py

Lines changed: 264 additions & 0 deletions
"""
Integration test that loss_mask correctly combines all masking sources:
- Negative labels (padding and image placeholders)
- loss_masking_spans

Tests the actual preprocess_batch code path in fast_llm/models/gpt/model.py
"""

import torch

from fast_llm.config import NoAutoValidate
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.data.sample.range import RangeBatch
from fast_llm.data.sample.token import TokenBatch
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.models.gpt.config import GPTBatchConfig, GPTModelConfig
from tests.utils.utils import get_base_model, requires_cuda


def create_test_batch(
    tokens: torch.Tensor,
    lengths: list[list[int]] | None = None,
    loss_masking_spans: list[list[tuple[int, int]]] | None = None,
) -> LanguageModelBatch:
    """Create a LanguageModelBatch for testing."""
    token_batch = TokenBatch(tokens, lengths)

    if loss_masking_spans is not None:
        range_batch = RangeBatch(loss_masking_spans, sample_size=tokens.shape[1])
    else:
        range_batch = None

    return LanguageModelBatch(
        tokens=token_batch,
        loss_masking_spans=range_batch,
    )


def get_minimal_model():
    """Create a minimal GPT model for testing."""
    config = GPTModelConfig.from_dict(
        {
            "base_model": {
                "decoder": {"num_blocks": 1},
                "embeddings": {"vocab_size": 1000},
                "hidden_size": 64,
            },
            "distributed": {},
        },
    )
    model, distributed = get_base_model(config)
    return model, distributed


def run_preprocess_batch(model, distributed_config, batch: LanguageModelBatch, phase: PhaseType = PhaseType.training):
    """
    Run preprocess_batch with proper GPTBatchConfig metadata.

    This avoids the code path that accesses prediction_heads directly.
    """
    micro_batch_size, sequence_length = batch.tokens.tokens.shape

    # Create GPTBatchConfig for metadata with proper setup
    with NoAutoValidate():
        batch_config = GPTBatchConfig(
            batch_size=micro_batch_size,
            sequence_length=sequence_length,
        )
    batch_config.setup(distributed_config)
    batch_config.validate()

    # Get preprocessed metadata using GPTBatchConfig
    preprocessed_meta = model.preprocess_meta(batch_config, phase)

    # Run preprocess_batch with the actual batch data
    return model.preprocess_batch(
        batch,
        preprocessed_meta=preprocessed_meta,
        phase=phase,
        iteration=0,
    )


@requires_cuda
class TestLossMaskIntegration:
    """
    Integration tests for loss_mask computation in preprocess_batch.

    These tests verify the masking behavior by checking labels, since:
    1. loss_mask = labels >= 0 (masks negative labels)
    2. loss_masking_spans positions are also masked
    3. labels are set to -100 at all masked positions

    So if labels are -100 at expected positions, the masking is working.
    """

    def test_negative_labels_preserved(self):
        """Test that negative input tokens result in negative labels (shifted by 1)."""
        model, distributed = get_minimal_model()

        # Sequence: [text, text, IMG(-100), IMG(-100), text, text, text, text]
        # Labels (shifted by 1): [text, IMG, IMG, text, text, text, text, ?]
        tokens = torch.tensor(
            [
                [100, 101, -100, -100, 104, 105, 106, 107],
            ],
            dtype=torch.int64,
        )

        batch = create_test_batch(tokens)
        preprocessed = run_preprocess_batch(model, distributed.config, batch)

        assert len(preprocessed) == 1
        _, kwargs = preprocessed[0]

        labels = kwargs[LanguageModelKwargs.labels]
        # Flatten for easier indexing (handles sequence_first)
        labels_flat = labels.flatten()

        # Labels at positions 1,2 should be -100 (the next token after positions 0,1 is -100)
        assert labels_flat[1].item() == -100, f"Label at position 1 should be -100, got {labels_flat[1].item()}"
        assert labels_flat[2].item() == -100, f"Label at position 2 should be -100, got {labels_flat[2].item()}"

        # Labels at other positions should be positive
        assert labels_flat[0].item() > 0, "Label at position 0 should be positive"
        assert labels_flat[3].item() > 0, "Label at position 3 should be positive"

    def test_loss_masking_spans_set_labels_to_negative(self):
        """Test that loss_masking_spans positions have labels set to -100."""
        model, distributed = get_minimal_model()

        # All positive tokens
        tokens = torch.tensor(
            [
                [100, 101, 102, 103, 104, 105, 106, 107],
            ],
            dtype=torch.int64,
        )

        # loss_masking_spans are in TOKEN space, but labels are shifted by 1
        # Span (3, 5) in token space -> after cropping with labels_begin=1 -> (2, 4) in label space
        # This will mask label positions 2 and 3
        loss_masking_spans = [[(3, 5)]]

        batch = create_test_batch(tokens, loss_masking_spans=loss_masking_spans)
        preprocessed = run_preprocess_batch(model, distributed.config, batch)

        assert len(preprocessed) == 1
        _, kwargs = preprocessed[0]

        labels = kwargs[LanguageModelKwargs.labels]
        labels_flat = labels.flatten()

        # After cropping, positions 2,3 in label space should be masked (set to -100)
        assert labels_flat[2].item() == -100, f"Label at position 2 should be -100, got {labels_flat[2].item()}"
        assert labels_flat[3].item() == -100, f"Label at position 3 should be -100, got {labels_flat[3].item()}"

        # Positions outside the span should be positive
        assert labels_flat[0].item() > 0, "Label at position 0 should be positive"
        assert labels_flat[1].item() > 0, "Label at position 1 should be positive"
        assert labels_flat[4].item() > 0, "Label at position 4 should be positive"

    def test_combined_masking_negative_labels_and_spans(self):
        """Test that both negative labels AND loss_masking_spans result in -100 labels."""
        model, distributed = get_minimal_model()

        # Tokens with -100 at positions 4,5 (will affect labels at 3,4)
        tokens = torch.tensor(
            [
                [100, 101, 102, 103, -100, -100, 106, 107],
            ],
            dtype=torch.int64,
        )

        # loss_masking_spans in token space: (2, 3) -> after cropping to label space: (1, 2)
        # This will mask label position 1
        loss_masking_spans = [[(2, 3)]]

        batch = create_test_batch(tokens, loss_masking_spans=loss_masking_spans)
        preprocessed = run_preprocess_batch(model, distributed.config, batch)

        assert len(preprocessed) == 1
        _, kwargs = preprocessed[0]

        labels = kwargs[LanguageModelKwargs.labels]
        labels_flat = labels.flatten()

        # Position 1 should be -100 (from loss_masking_spans after cropping)
        assert labels_flat[1].item() == -100, f"Position 1 should be -100 (from spans), got {labels_flat[1].item()}"

        # Positions 3,4 should be -100 (from negative input tokens at positions 4,5)
        assert labels_flat[3].item() == -100, f"Position 3 should be -100 (from IMG), got {labels_flat[3].item()}"
        assert labels_flat[4].item() == -100, f"Position 4 should be -100 (from IMG), got {labels_flat[4].item()}"

        # Positions 0, 2, 5 should be positive (not masked)
        assert labels_flat[0].item() > 0, "Position 0 should be positive"
        assert labels_flat[2].item() > 0, "Position 2 should be positive"
        assert labels_flat[5].item() > 0, "Position 5 should be positive"

    def test_all_padding_sample(self):
        """Test that a sample with all -100 tokens (padding) results in all -100 labels."""
        model, distributed = get_minimal_model()

        # Sample 0: normal tokens
        # Sample 1: all padding (-100)
        tokens = torch.tensor(
            [
                [100, 101, 102, 103, 104, 105, 106, 107],
                [-100, -100, -100, -100, -100, -100, -100, -100],
            ],
            dtype=torch.int64,
        )

        batch = create_test_batch(tokens)
        preprocessed = run_preprocess_batch(model, distributed.config, batch)

        assert len(preprocessed) == 1
        _, kwargs = preprocessed[0]

        labels = kwargs[LanguageModelKwargs.labels]

        # Get labels for sample 1 (all should be -100)
        # Handle sequence_first dimension ordering
        if labels.shape[0] > labels.shape[1]:
            # sequence_first=True: shape is (seq, batch)
            sample1_labels = labels[:, 1]
        else:
            # sequence_first=False: shape is (batch, seq)
            sample1_labels = labels[1, :]

        assert torch.all(sample1_labels == -100), f"All labels in padding sample should be -100, got {sample1_labels}"

    def test_image_placeholders_interleaved(self):
        """Test realistic scenario: text, image placeholders, text interleaved."""
        model, distributed = get_minimal_model()

        # Realistic sequence: [BOS, text, IMG, IMG, IMG, text, text, EOS]
        # Labels should be: [text, IMG(-100), IMG(-100), IMG(-100), text, text, EOS, ?]
        tokens = torch.tensor(
            [
                [1, 100, -100, -100, -100, 200, 201, 2],
            ],
            dtype=torch.int64,
        )

        batch = create_test_batch(tokens)
        preprocessed = run_preprocess_batch(model, distributed.config, batch)

        assert len(preprocessed) == 1
        _, kwargs = preprocessed[0]

        labels = kwargs[LanguageModelKwargs.labels]
        labels_flat = labels.flatten()

        # Labels at positions 1,2,3 should be -100 (next tokens are IMG)
        assert labels_flat[1].item() == -100, f"Position 1 should be -100, got {labels_flat[1].item()}"
        assert labels_flat[2].item() == -100, f"Position 2 should be -100, got {labels_flat[2].item()}"
        assert labels_flat[3].item() == -100, f"Position 3 should be -100, got {labels_flat[3].item()}"

        # Labels at positions 0, 4, 5 should be positive
        assert labels_flat[0].item() > 0, f"Position 0 should be positive, got {labels_flat[0].item()}"
        assert labels_flat[4].item() > 0, f"Position 4 should be positive, got {labels_flat[4].item()}"
        assert labels_flat[5].item() > 0, f"Position 5 should be positive, got {labels_flat[5].item()}"
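
A side note on the indexing used in these tests: the single-sample tests simply call labels.flatten(), while test_all_padding_sample checks the tensor shape before picking the sample axis. A small standalone sketch in plain torch (shapes assumed from the tests: one or two samples of sequence length 8, not the Fast-LLM API) shows why both approaches work.

# Standalone illustration (plain torch) of the sequence_first handling used above.
import torch

batch_size, sequence_length = 2, 8
labels_batch_first = torch.arange(batch_size * sequence_length).reshape(batch_size, sequence_length)
labels_seq_first = labels_batch_first.transpose(0, 1)  # shape (seq, batch), as when sequence_first=True

# With a single sample, flatten() yields the same 1-D order either way,
# which is why the single-sample tests can just use labels.flatten().
single = labels_batch_first[:1]
assert torch.equal(single.flatten(), single.transpose(0, 1).flatten())

# With two samples the flattened orders differ, so test_all_padding_sample
# selects the sample axis explicitly based on the tensor shape instead.
sample1_batch_first = labels_batch_first[1, :]
sample1_seq_first = labels_seq_first[:, 1]
assert torch.equal(sample1_batch_first, sample1_seq_first)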
