File tree Expand file tree Collapse file tree 1 file changed +31
-31
lines changed Expand file tree Collapse file tree 1 file changed +31
-31
lines changed Original file line number Diff line number Diff line change @@ -38,39 +38,39 @@ pub(crate) fn flash_attn_varlen(
38
38
}
39
39
#[ cfg( not( feature = "flash-attn-v1" ) ) ]
40
40
candle:: bail!( "Flash attention v1 is not installed. Use `flash-attn-v1` feature." )
41
- } else if ( 80 ..90 ) . contains ( & runtime_compute_cap) {
41
+ } else if ( 80 ..90 ) . contains ( & runtime_compute_cap) || runtime_compute_cap == 90 {
42
42
#[ cfg( feature = "flash-attn" ) ]
43
43
{
44
- use candle_flash_attn:: flash_attn_varlen;
45
- return flash_attn_varlen (
46
- q ,
47
- k ,
48
- v ,
49
- seqlens_q ,
50
- seqlens_k ,
51
- max_seqlen_q ,
52
- max_seqlen_k ,
53
- softmax_scale ,
54
- causal ,
55
- ) ;
56
- }
57
- # [ cfg ( not ( feature = "flash-attn" ) ) ]
58
- candle :: bail! ( "Flash attention is not installed. Use `flash-attn-v1` feature." )
59
- } else if runtime_compute_cap == 90 {
60
- # [ cfg ( feature = "flash-attn" ) ]
61
- {
62
- use candle_flash_attn :: flash_attn_varlen ;
63
- return flash_attn_varlen (
64
- q ,
65
- k ,
66
- v ,
67
- seqlens_q ,
68
- seqlens_k ,
69
- max_seqlen_q ,
70
- max_seqlen_k ,
71
- softmax_scale ,
72
- causal ,
73
- ) ;
44
+ use candle_flash_attn:: { flash_attn_varlen, flash_attn_varlen_alibi } ;
45
+
46
+ let attention = if let Some ( alibi_slopes ) = alibi_slopes {
47
+ flash_attn_varlen_alibi (
48
+ q ,
49
+ k ,
50
+ v ,
51
+ alibi_slopes ,
52
+ seqlens_q ,
53
+ seqlens_k ,
54
+ max_seqlen_q ,
55
+ max_seqlen_k ,
56
+ softmax_scale ,
57
+ causal ,
58
+ )
59
+ } else {
60
+ flash_attn_varlen (
61
+ q ,
62
+ k ,
63
+ v ,
64
+ seqlens_q ,
65
+ seqlens_k ,
66
+ max_seqlen_q ,
67
+ max_seqlen_k ,
68
+ softmax_scale ,
69
+ causal ,
70
+ )
71
+ } ;
72
+
73
+ return attention ;
74
74
}
75
75
#[ cfg( not( feature = "flash-attn" ) ) ]
76
76
candle:: bail!( "Flash attention is not installed. Use `flash-attn-v1` feature." )
You can’t perform that action at this time.
0 commit comments