|
 
 
 if is_flash_attn_2_available():
-    from transformers.models.mistral.modeling_mistral import _flash_supports_window_size
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 
 logger = logging.get_logger(__name__)
@@ -173,18 +173,6 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
-
-        if not _flash_supports_window_size:
-            logger.warning_once(
-                "The current flash attention version does not support sliding window attention, for a more memory"
-                " efficient implementation make sure to upgrade flash-attn library."
-            )
-
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
             cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
@@ -257,14 +245,17 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
             query_states,
             key_states,
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
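For context: the commit drops the per-model sliding-window bookkeeping (`_flash_supports_window_size`, `use_sliding_windows`) and the attention class's own `self._flash_attention_forward` method in favor of the shared `_flash_attention_forward` helper from `transformers.modeling_flash_attention_utils`, which takes `sliding_window`, `use_top_left_mask`, and `is_causal` directly. Below is a minimal, standalone sketch of a call to that helper mirroring the arguments visible in the diff; the tensor shapes, dtype, and `sliding_window=4096` are illustrative assumptions rather than values from this model's config.

# Hypothetical caller-side sketch of the shared helper used in the diff above.
import torch
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available() and torch.cuda.is_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

    bsz, q_len, num_heads, head_dim = 1, 16, 8, 64
    # The helper expects (batch, seq_len, num_heads, head_dim) inputs in fp16/bf16.
    query_states = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    key_states = torch.randn_like(query_states)
    value_states = torch.randn_like(query_states)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        None,                      # attention_mask: None for an unpadded batch
        q_len,
        position_ids=None,
        dropout=0.0,
        sliding_window=4096,       # the model passes getattr(self.config, "sliding_window", None)
        use_top_left_mask=False,   # the model passes self._flash_attn_uses_top_left_mask
        is_causal=True,            # the model passes self.is_causal
    )
    print(attn_output.shape)       # torch.Size([1, 16, 8, 64])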
|