Commit b3f6bb1

Sliding window fixes (#1738)
* Add tests for sliding window issues
* Fix for sliding window issues
1 parent 4d1659e commit b3f6bb1

3 files changed: +89 −29 lines

keras_nlp/src/models/gemma/gemma_attention.py
Lines changed: 37 additions & 24 deletions

@@ -122,6 +122,7 @@ def _compute_attention(
         v,
         attention_mask,
         training=False,
+        cache_update_index=0,
     ):
         if self.query_head_dim_normalize:
             query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
         )

         if self.use_sliding_window_attention:
-            all_ones = ops.ones_like(attention_mask)
-            if keras.config.backend() == "tensorflow":
-                import tensorflow as tf
-
-                sliding_window_size = ops.minimum(
-                    self.sliding_window_size - 1, q_len
-                )
-                sliding_window_size = ops.cast(
-                    sliding_window_size, dtype="int32"
-                )
-                sliding_mask = tf.linalg.band_part(
-                    all_ones, sliding_window_size - 1, sliding_window_size - 1
-                )
-                sliding_mask = ops.cast(sliding_mask, dtype="bool")
-                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
-                attention_mask = tf.math.logical_and(
-                    sliding_mask, bool_attention_mask
-                )
-            else:
-                sliding_mask = ops.triu(
-                    all_ones, -1 * self.sliding_window_size + 1
-                ) * ops.tril(all_ones, self.sliding_window_size - 1)
-                attention_mask = sliding_mask * attention_mask
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )

         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
@@ -189,6 +171,32 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))

+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: triu/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
     def call(
         self,
         x,
@@ -216,7 +224,12 @@ def call(
         value = self.value_dense(x)

         attention_vec = self._compute_attention(
-            query, key, value, attention_mask, training=training
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
         )

         # Wipe attn vec if there are no attended tokens.
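A minimal standalone sketch of what the new `_mask_sliding_window` helper computes, written with plain NumPy rather than `keras.ops` so it can be run outside the model (the window size and cache index below are illustrative, not values taken from this commit):

```python
import numpy as np


def sliding_window_mask(query_len, key_len, window_size, cache_update_index=0):
    """Band mask mirroring the triu/tril branch of `_mask_sliding_window`."""
    all_ones = np.ones((key_len, key_len), dtype=bool)
    # Keep positions within `window_size - 1` steps of the diagonal.
    band = np.triu(all_ones, -window_size + 1) & np.tril(all_ones, window_size - 1)
    # During cached generation the query covers only `query_len` rows starting
    # at `cache_update_index`, so slice those rows out of the square mask.
    return band[cache_update_index : cache_update_index + query_len, :key_len]


# Full prefill: 10 queries over 10 keys with a window of 5.
print(sliding_window_mask(10, 10, 5).astype(int))
# Single-token decode step at position 7: one row of the same band.
print(sliding_window_mask(1, 10, 5, cache_update_index=7).astype(int))
```

The slice appears to be the heart of the fix: the removed code built the band directly from the (possibly length-1) `attention_mask`, so a single-token decode step got a band anchored at position 0 rather than at `cache_update_index`.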

keras_nlp/src/models/gemma/gemma_backbone_test.py
Lines changed: 21 additions & 0 deletions

@@ -204,6 +204,27 @@ def test_backbone_basics(self):
             expected_output_shape=(2, 10, 16),
         )

+    def test_sliding_window(self):
+        # Test sliding window correctness by hand.
+        backbone = GemmaBackbone(**self.init_kwargs)
+        attention = backbone.transformer_layers[0].attention
+        mask = attention._mask_sliding_window(ops.ones((1, 10, 10), "bool"))
+        expected = [
+            [
+                [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
+            ]
+        ]
+        self.assertAllEqual(mask, expected)
+
     @pytest.mark.large
     def test_saved_model(self):
         self.run_model_saving_test(
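In this test the input mask is all ones, so the result is the bare band shown in `expected`; inside the model the same helper ANDs that band with the causal/padding mask. A rough NumPy sketch of that combination (the window of 5 is assumed from the five-wide band above, and the batch dimension is dropped for brevity):

```python
import numpy as np

seq_len, window = 10, 5  # window assumed from the 5-wide band in `expected`
ones = np.ones((seq_len, seq_len), dtype=bool)
causal = np.tril(ones)  # standard causal mask: attend to self and earlier tokens only
band = np.triu(ones, -window + 1) & np.tril(ones, window - 1)
combined = causal & band  # what the helper returns when handed a causal input mask
print(combined.astype(int))  # each row keeps itself plus at most `window - 1` prior tokens
```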

keras_nlp/src/models/gemma/gemma_causal_lm_test.py
Lines changed: 31 additions & 5 deletions

@@ -39,14 +39,22 @@ def setUp(self):
             self.tokenizer,
             sequence_length=8,
         )
+        # Test Gemma 2 like config, as it's the more complicated case.
         self.backbone = GemmaBackbone(
             vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
             num_layers=2,
-            num_query_heads=2,
-            num_key_value_heads=1,
-            hidden_dim=4,
-            intermediate_dim=8,
+            num_query_heads=4,
+            num_key_value_heads=2,
+            hidden_dim=8,
+            intermediate_dim=16,
             head_dim=2,
+            sliding_window_size=3,
+            use_sliding_window_attention=True,
+            attention_logit_soft_cap=50,
+            final_logit_soft_cap=30,
+            query_head_dim_normalize=False,
+            use_post_ffw_norm=True,
+            use_post_attention_norm=True,
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
@@ -63,6 +71,24 @@ def test_causal_lm_basics(self):
             expected_output_shape=(2, 8, 11),
         )

+    def test_cache_correctness(self):
+        token_ids = self.input_data["token_ids"]
+        padding_mask = ops.ones_like(self.input_data["padding_mask"])
+        causal_lm = GemmaCausalLM(**self.init_kwargs)
+        full_logits = causal_lm(
+            {"token_ids": token_ids, "padding_mask": padding_mask}
+        )
+        token_ids = self.input_data["token_ids"]
+        _, cache = causal_lm._build_cache(token_ids)
+        cache = ops.zeros_like(cache)
+        cached_logits = []
+        for i in range(self.preprocessor.sequence_length):
+            sliced = token_ids[:, i][:, None]
+            logits, _, cache = causal_lm.call_with_cache(sliced, cache, i)
+            cached_logits.append(logits)
+        cached_logits = ops.concatenate(cached_logits, 1)
+        self.assertAllClose(full_logits, cached_logits, atol=0.002)
+
     def test_generate(self):
         causal_lm = GemmaCausalLM(**self.init_kwargs)
         # String input.
@@ -230,7 +256,7 @@ def test_score_layer_intercept_fn_exfiltration(self):
         # Setup prompts, models, and associated expected shapes.
         prompts = ["the quick brown fox", "the quick brown fox"]
         causal_lm = GemmaCausalLM(**self.init_kwargs)
-        expected_embedded_shape = (2, 8, 4)
+        expected_embedded_shape = (2, 8, 8)
         expected_score_shape = (2, 8, 11)

         # Preprocess prompts to get tokenized representations and padding masks.
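`test_cache_correctness` captures the general recipe for catching this class of bug: run the model once over the full sequence, then replay the same tokens one step at a time through the cache path and compare logits. A schematic version of that loop, where `model`, `build_cache`, and `call_with_cache` are hypothetical stand-ins for whatever cached decoder is under test (the Gemma test above uses `causal_lm`, `_build_cache`, and `call_with_cache`, which also returns hidden states):

```python
import numpy as np


def check_cached_decoding(model, build_cache, call_with_cache, token_ids, atol=2e-3):
    """Compare full-sequence logits against one-token-at-a-time cached decoding."""
    full_logits = model(token_ids)  # single pass over the whole sequence
    cache = np.zeros_like(build_cache(token_ids))  # start from an empty cache
    step_logits = []
    for i in range(token_ids.shape[1]):
        # Feed token i alone, telling the model where it sits in the sequence.
        logits, cache = call_with_cache(token_ids[:, i : i + 1], cache, i)
        step_logits.append(logits)
    np.testing.assert_allclose(full_logits, np.concatenate(step_logits, axis=1), atol=atol)
```

Any mismatch between the prefill mask and the per-step mask, such as the sliding-window slicing issue fixed here, surfaces as a logits divergence at the affected positions.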
