Skip to content

Commit b083e00

Browse files
fix: fix flash jina
1 parent 2f75be2 commit b083e00

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

backends/candle/src/flash_attn.rs

Lines changed: 31 additions & 31 deletions
Original file line number · Diff line number · Diff line change
@@ -38,39 +38,39 @@ pub(crate) fn flash_attn_varlen(
3838
}
3939
#[cfg(not(feature = "flash-attn-v1"))]
4040
candle::bail!("Flash attention v1 is not installed. Use `flash-attn-v1` feature.")
41-
} else if (80..90).contains(&runtime_compute_cap) {
41+
} else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 {
4242
#[cfg(feature = "flash-attn")]
4343
{
44-
use candle_flash_attn::flash_attn_varlen;
45-
return flash_attn_varlen(
46-
q,
47-
k,
48-
v,
49-
seqlens_q,
50-
seqlens_k,
51-
max_seqlen_q,
52-
max_seqlen_k,
53-
softmax_scale,
54-
causal,
55-
);
56-
}
57-
#[cfg(not(feature = "flash-attn"))]
58-
candle::bail!("Flash attention is not installed. Use `flash-attn-v1` feature.")
59-
} else if runtime_compute_cap == 90 {
60-
#[cfg(feature = "flash-attn")]
61-
{
62-
use candle_flash_attn::flash_attn_varlen;
63-
return flash_attn_varlen(
64-
q,
65-
k,
66-
v,
67-
seqlens_q,
68-
seqlens_k,
69-
max_seqlen_q,
70-
max_seqlen_k,
71-
softmax_scale,
72-
causal,
73-
);
44+
use candle_flash_attn::{flash_attn_varlen, flash_attn_varlen_alibi};
45+
46+
let attention = if let Some(alibi_slopes) = alibi_slopes {
47+
flash_attn_varlen_alibi(
48+
q,
49+
k,
50+
v,
51+
alibi_slopes,
52+
seqlens_q,
53+
seqlens_k,
54+
max_seqlen_q,
55+
max_seqlen_k,
56+
softmax_scale,
57+
causal,
58+
)
59+
} else {
60+
flash_attn_varlen(
61+
q,
62+
k,
63+
v,
64+
seqlens_q,
65+
seqlens_k,
66+
max_seqlen_q,
67+
max_seqlen_k,
68+
softmax_scale,
69+
causal,
70+
)
71+
};
72+
73+
return attention;
7474
}
7575
#[cfg(not(feature = "flash-attn"))]
7676
candle::bail!("Flash attention is not installed. Use `flash-attn-v1` feature.")

0 commit comments

Comments (0)