@@ -133,6 +133,13 @@ def _compute_attention(
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
         if self._can_use_flash_attention():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
@@ -172,13 +179,8 @@ def _compute_attention(
                 ops.tanh(attention_logits), self.logit_soft_cap
             )
 
-        if self.use_sliding_window_attention:
-            attention_mask = self._mask_sliding_window(
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
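
For context, the relocated `_mask_sliding_window` call restricts each query to a fixed window of recent keys and now runs before the flash-attention branch, so both attention paths receive the same windowed mask and a `None` mask is left untouched. Below is a minimal NumPy sketch of that style of windowed masking; the function name, window size, and mask shape are illustrative assumptions, not the KerasHub implementation.

    import numpy as np

    def mask_sliding_window(attention_mask, cache_update_index=0, sliding_window_size=4):
        # attention_mask: boolean array of shape [batch, query_len, key_len].
        _, query_len, key_len = attention_mask.shape
        # Absolute query positions, shifted by the cache index during cached decoding.
        query_pos = np.arange(query_len)[:, None] + cache_update_index
        key_pos = np.arange(key_len)[None, :]
        # A key stays visible only if it lies within the window behind its query.
        window = (query_pos - key_pos) < sliding_window_size
        return attention_mask & window[None, :, :]

    # Usage: narrow a causal mask over 6 tokens to a window of 4 keys.
    causal = np.tril(np.ones((1, 6, 6), dtype=bool))
    print(mask_sliding_window(causal, sliding_window_size=4)[0].astype(int))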