@@ -122,6 +122,7 @@ def _compute_attention(
         v,
         attention_mask,
         training=False,
+        cache_update_index=0,
     ):
         if self.query_head_dim_normalize:
             query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
             )

         if self.use_sliding_window_attention:
-            all_ones = ops.ones_like(attention_mask)
-            if keras.config.backend() == "tensorflow":
-                import tensorflow as tf
-
-                sliding_window_size = ops.minimum(
-                    self.sliding_window_size - 1, q_len
-                )
-                sliding_window_size = ops.cast(
-                    sliding_window_size, dtype="int32"
-                )
-                sliding_mask = tf.linalg.band_part(
-                    all_ones, sliding_window_size - 1, sliding_window_size - 1
-                )
-                sliding_mask = ops.cast(sliding_mask, dtype="bool")
-                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
-                attention_mask = tf.math.logical_and(
-                    sliding_mask, bool_attention_mask
-                )
-            else:
-                sliding_mask = ops.triu(
-                    all_ones, -1 * self.sliding_window_size + 1
-                ) * ops.tril(all_ones, self.sliding_window_size - 1)
-                attention_mask = sliding_mask * attention_mask
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )

         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
@@ -189,6 +171,32 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))

+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: triu/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
     def call(
         self,
         x,
@@ -216,7 +224,12 @@ def call(
         value = self.value_dense(x)

         attention_vec = self._compute_attention(
-            query, key, value, attention_mask, training=training
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
         )

         # Wipe attn vec if there are no attended tokens.
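
For context when reviewing the new `_mask_sliding_window` helper, here is a minimal standalone sketch (NumPy only, with arbitrarily chosen sizes) of the banded mask it builds, and of how slicing at `cache_update_index` picks out just the rows a short query needs when decoding against a longer cache. It illustrates the idea; it is not code from this change.

# Illustrative only: NumPy stand-in for the keras.ops calls in the diff.
import numpy as np

sliding_window_size = 4  # assumed value for the sketch
key_len = 8

# Square band: position i may attend to keys j with |i - j| < window size.
all_ones = np.ones((key_len, key_len), dtype=bool)
band = np.triu(all_ones, -sliding_window_size + 1) & np.tril(
    all_ones, sliding_window_size - 1
)
print(band.astype(int))  # 8x8 band of width 2 * 4 - 1 = 7

# Cached generation: one new token at position `cache_update_index` only
# needs its own row of the band, which is what the ops.slice call selects.
cache_update_index, query_len = 5, 1
row = band[cache_update_index : cache_update_index + query_len, :key_len]
print(row.astype(int))  # [[0 0 1 1 1 1 1 1]]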
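A quick equivalence check for the TensorFlow branch (assuming TensorFlow is installed; sizes are arbitrary): `tf.linalg.band_part` with `num_lower = num_upper = sliding_window_size - 1` keeps the same band as the `triu`/`tril` product used on the other backends, which is why it can stand in while the dynamic-shape issue noted in the TODO is open.

# Sanity check, not part of the change: band_part vs. triu/tril.
import numpy as np
import tensorflow as tf

key_len, window = 8, 4  # assumed sizes
all_ones = np.ones((key_len, key_len), dtype=bool)

via_triu_tril = np.triu(all_ones, -window + 1) & np.tril(all_ones, window - 1)
via_band_part = tf.linalg.band_part(
    tf.ones((key_len, key_len)), window - 1, window - 1
)

np.testing.assert_array_equal(
    via_triu_tril, tf.cast(via_band_part, "bool").numpy()
)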