
Commit f543d94

Fix a bug with computing the output mask after generate (#1029)
We were calling cumsum with the wrong axis, meaning we were not correctly masking all positions after an end token.
Parent commit: 7e8f5a3
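
To make the fix concrete, here is a small illustrative sketch (toy shapes, not taken from the keras-nlp source): for a [batch, sequence] tensor of end-token locations, tf.math.cumsum defaults to axis=0, so the buggy call accumulated down the batch dimension instead of along each sequence, and positions after an end token were never flagged as overflow.

import tensorflow as tf

# Toy [batch=2, sequence=5] tensor; 1 marks where an end token was generated.
end_locations = tf.constant(
    [
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
    ],
    dtype=tf.int32,
)

# Buggy call: the default axis is 0, so the exclusive cumulative sum runs
# down the batch dimension and the first sequence never sees its end token.
buggy_overflow = tf.math.cumsum(end_locations, exclusive=True)
# [[0, 0, 0, 0, 0],
#  [0, 0, 1, 0, 0]]

# Fixed call: axis=-1 accumulates along the sequence, so every position
# strictly after an end token becomes 1 (exclusive=True keeps the end token
# itself out of the overflow).
fixed_overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
# [[0, 0, 0, 1, 1],
#  [0, 0, 1, 1, 1]]

# The padding mask is the inverse: True for tokens we keep in the output.
padding_mask = ~tf.cast(fixed_overflow, tf.bool)
# [[ True,  True,  True, False, False],
#  [ True,  True, False, False, False]]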

4 files changed, 44 additions and 6 deletions

keras_nlp/models/gpt2/gpt2_causal_lm.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def next(prompt, cache, index):
             end_locations = (token_ids == end_token_id) & (~padding_mask)
             end_locations = tf.cast(end_locations, tf.int32)
             # Use cumsum to get ones in all locations after end_locations.
-            overflow = tf.math.cumsum(end_locations, exclusive=True)
+            overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
             # Our padding mask is the inverse of these overflow locations.
             padding_mask = ~tf.cast(overflow, tf.bool)
         else:
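
For reference, a self-contained walkthrough of the masking logic above, using made-up token ids: end tokens are only counted where the original padding mask is False (positions the model generated, not the prompt), and the exclusive cumulative sum then marks everything strictly after them, so the end token itself survives in the output.

import tensorflow as tf

end_token_id = 3  # illustrative id, not the real GPT-2 end token
# One sequence: a two-token prompt, then the model generates 7, the end
# token, and two more tokens the sampler kept producing afterwards.
token_ids = tf.constant([[11, 22, 7, 3, 99, 99]])
# Original padding mask: True over the prompt tokens only.
padding_mask = tf.constant([[True, True, False, False, False, False]])

# Only count end tokens the model generated, not any in the prompt.
end_locations = (token_ids == end_token_id) & (~padding_mask)
end_locations = tf.cast(end_locations, tf.int32)  # [[0, 0, 0, 1, 0, 0]]
# Ones in all locations strictly after the end token.
overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
# [[0, 0, 0, 0, 1, 1]]
# Keep the prompt, the generated tokens, and the end token; drop the rest.
padding_mask = ~tf.cast(overflow, tf.bool)
# [[True, True, True, True, False, False]]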

keras_nlp/models/gpt2/gpt2_causal_lm_test.py

Lines changed: 21 additions & 2 deletions
@@ -14,7 +14,9 @@
 """Tests for GPT2 causal LM model."""

 import os
+from unittest.mock import patch

+import numpy as np
 import pytest
 import tensorflow as tf
 from absl.testing import parameterized
@@ -54,8 +56,8 @@ def setUp(self):
             vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
             num_layers=2,
             num_heads=2,
-            hidden_dim=64,
-            intermediate_dim=128,
+            hidden_dim=4,
+            intermediate_dim=8,
             max_sequence_length=self.preprocessor.packer.sequence_length,
         )
         self.causal_lm = GPT2CausalLM(
@@ -118,6 +120,23 @@ def test_generate(self):
             self.preprocessed_batch["padding_mask"][:, :5],
         )

+    def test_early_stopping(self):
+        call_with_cache = self.causal_lm.call_with_cache
+
+        def wrapper(*args, **kwargs):
+            """Modify output logits to always favor end_token_id"""
+            logits, hidden_states, cache = call_with_cache(*args, **kwargs)
+            logits = np.zeros(logits.shape.as_list())
+            logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9
+            return logits, hidden_states, cache
+
+        with patch.object(self.causal_lm, "call_with_cache", wraps=wrapper):
+            prompt = [" airplane at airport", " airplane"]
+            output = self.causal_lm.generate(prompt)
+            # We should immediately abort and output the prompt.
+            self.assertEqual(prompt, output)
+            self.assertEqual(self.causal_lm.call_with_cache.call_count, 2)
+
     def test_generate_compilation(self):
         # Assert we do not recompile with successive calls.
         self.causal_lm.generate(self.raw_batch)
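
The new test leans on unittest.mock's wraps behavior. A minimal standalone sketch of the same pattern (FakeLM and its method are invented names): patch.object(..., wraps=fn) swaps the method for a MagicMock that forwards every call to fn while still recording call_count, which is how the test can both rewrite the logits to favor end_token_id and assert how many forward passes generate() made.

from unittest.mock import patch


class FakeLM:
    """Minimal stand-in for the model under test."""

    def call_with_cache(self, x):
        return x + 1


lm = FakeLM()
real_call = lm.call_with_cache  # bind the real method before patching


def wrapper(*args, **kwargs):
    # Run the real method, then tamper with its output, mirroring how the
    # test above rewrites the logits.
    out = real_call(*args, **kwargs)
    return out * 10


with patch.object(lm, "call_with_cache", wraps=wrapper):
    assert lm.call_with_cache(2) == 30  # the wrapper's output, not 3
    assert lm.call_with_cache.call_count == 1  # calls are still counted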

keras_nlp/models/opt/opt_causal_lm.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def next(prompt, cache, index):
             end_locations = (token_ids == end_token_id) & (~padding_mask)
             end_locations = tf.cast(end_locations, tf.int32)
             # Use cumsum to get ones in all locations after end_locations.
-            overflow = tf.math.cumsum(end_locations, exclusive=True)
+            overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
             # Our padding mask is the inverse of these overflow locations.
             padding_mask = ~tf.cast(overflow, tf.bool)
         else:

keras_nlp/models/opt/opt_causal_lm_test.py

Lines changed: 21 additions & 2 deletions
@@ -14,7 +14,9 @@
 """Tests for OPT causal LM model."""

 import os
+from unittest.mock import patch

+import numpy as np
 import pytest
 import tensorflow as tf
 from absl.testing import parameterized
@@ -60,8 +62,8 @@ def setUp(self):
             vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
             num_layers=2,
             num_heads=2,
-            hidden_dim=64,
-            intermediate_dim=128,
+            hidden_dim=4,
+            intermediate_dim=8,
             max_sequence_length=self.preprocessor.packer.sequence_length,
         )
         self.causal_lm = OPTCausalLM(
@@ -124,6 +126,23 @@ def test_generate(self):
             self.preprocessed_batch["padding_mask"][:, :5],
         )

+    def test_early_stopping(self):
+        call_with_cache = self.causal_lm.call_with_cache
+
+        def wrapper(*args, **kwargs):
+            """Modify output logits to always favor end_token_id"""
+            logits, hidden_states, cache = call_with_cache(*args, **kwargs)
+            logits = np.zeros(logits.shape.as_list())
+            logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9
+            return logits, hidden_states, cache
+
+        with patch.object(self.causal_lm, "call_with_cache", wraps=wrapper):
+            prompt = [" airplane at airport", " airplane"]
+            output = self.causal_lm.generate(prompt)
+            # We should immediately abort and output the prompt.
+            self.assertEqual(prompt, output)
+            self.assertEqual(self.causal_lm.call_with_cache.call_count, 2)
+
     def test_generate_compilation(self):
         # Assert we do not recompile with successive calls.
         self.causal_lm.generate(self.raw_batch)
