
Commit bf684ad

fix flash attention for mistral. (#758)
This pull request fixes the flash attention forward method for Mistral by switching from the model-specific flash-attention code to the shared `_flash_attention_forward` helper in `transformers.modeling_flash_attention_utils`.
1 parent: 0e18a53 · commit: bf684ad

1 file changed: +6 −15 lines

src/adapters/models/mistral/modeling_mistral.py

```diff
@@ -45,7 +45,7 @@
 
 
 if is_flash_attn_2_available():
-    from transformers.models.mistral.modeling_mistral import _flash_supports_window_size
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 
 logger = logging.get_logger(__name__)
```
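
The import swap reflects upstream Transformers consolidating its flash-attention plumbing into `transformers.modeling_flash_attention_utils`. As a minimal sketch, the guarded import reads as follows after this change; the `transformers.utils` import for `is_flash_attn_2_available` is an assumption, since that line sits outside this hunk:

```python
# Minimal sketch of the guarded import after this commit (assumes the
# availability check is imported from transformers.utils, as upstream does).
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # Shared helper replacing the per-model flash-attention code.
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
```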
```diff
@@ -173,18 +173,6 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
-
-        if not _flash_supports_window_size:
-            logger.warning_once(
-                "The current flash attention version does not support sliding window attention, for a more memory"
-                " efficient implementation make sure to upgrade flash-attn library."
-            )
-
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
             cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
```
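
The deleted `use_sliding_windows` computation and the version warning are not lost: the shared helper performs an equivalent check internally. A hedged reconstruction of that predicate, pieced together from the removed lines above (the upstream helper's exact code may differ):

```python
# Hedged reconstruction of the window check that moves into the shared helper.
# flash_supports_window_size: whether the installed flash-attn accepts window_size.
def use_sliding_windows(flash_supports_window_size: bool, sliding_window, kv_seq_len: int) -> bool:
    return (
        flash_supports_window_size
        and sliding_window is not None
        and kv_seq_len > sliding_window
    )
```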
```diff
@@ -257,14 +245,17 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
             query_states,
             key_states,
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
```
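
To make the new call shape concrete, here is a hedged, self-contained usage sketch mirroring the updated call site. Tensor sizes are illustrative, and running it requires a CUDA device with the flash-attn package installed; treat it as a sketch of the argument layout rather than a definitive harness:

```python
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

# Illustrative sizes; real values come from the Mistral config.
bsz, q_len, num_heads, head_dim = 2, 16, 8, 64

# After the transposes in the diff, tensors are (batch, seq_len, heads, head_dim).
query_states = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
key_states = torch.randn_like(query_states)
value_states = torch.randn_like(query_states)

attn_output = _flash_attention_forward(
    query_states,
    key_states,
    value_states,
    None,   # attention_mask: None for a fully dense batch
    q_len,
    position_ids=None,
    dropout=0.0,
    sliding_window=4096,      # forwarded from config; helper checks flash-attn support
    use_top_left_mask=False,  # mirrors self._flash_attn_uses_top_left_mask
    is_causal=True,           # mirrors self.is_causal
)
# Output is (bsz, q_len, num_heads, head_dim); the caller reshapes to hidden_size.
```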
