     SPARGE_ATTN_AVAILABLE,
 )
 
+FA3_MAX_HEADDIM = 256
+
 logger = logging.get_logger(__name__)
 
 
@@ -130,31 +132,40 @@ def attention(
         "sage_attn",
         "sparge_attn",
     ]
+    flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
-            return flash_attn3(q, k, v, softmax_scale=scale)
-        elif XFORMERS_AVAILABLE:
+            if flash_attn3_compatible:
+                return flash_attn3(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
+        if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif SDPA_AVAILABLE:
+        if SDPA_AVAILABLE:
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif FLASH_ATTN_2_AVAILABLE:
+        if FLASH_ATTN_2_AVAILABLE:
             return flash_attn2(q, k, v, softmax_scale=scale)
-        else:
-            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
     else:
         if attn_impl == "eager":
             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "flash_attn_3":
+        if attn_impl == "flash_attn_3":
+            if not flash_attn3_compatible:
+                raise RuntimeError(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
+                )
             return flash_attn3(q, k, v, softmax_scale=scale)
-        elif attn_impl == "flash_attn_2":
+        if attn_impl == "flash_attn_2":
             return flash_attn2(q, k, v, softmax_scale=scale)
-        elif attn_impl == "xformers":
+        if attn_impl == "xformers":
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "sdpa":
+        if attn_impl == "sdpa":
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "sage_attn":
+        if attn_impl == "sage_attn":
             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "sparge_attn":
+        if attn_impl == "sparge_attn":
             return sparge_attn(
                 q,
                 k,
@@ -166,8 +177,7 @@ def attention(
                 cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
                 pvthreshd=kwargs.get("sparge_pvthreshd", 50),
             )
-        else:
-            raise ValueError(f"Invalid attention implementation: {attn_impl}")
+        raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
 
 class Attention(nn.Module):
@@ -240,32 +250,42 @@ def long_context_attention(
         "sage_attn",
         "sparge_attn",
     ]
+    flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
-            attn_func = LongContextAttention(attn_type=AttnType.FA3)
-        elif SDPA_AVAILABLE:
-            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
-        elif FLASH_ATTN_2_AVAILABLE:
-            attn_func = LongContextAttention(attn_type=AttnType.FA)
-        else:
-            raise ValueError("No available long context attention implementation")
+            if flash_attn3_compatible:
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
+        if SDPA_AVAILABLE:
+            return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
+        if FLASH_ATTN_2_AVAILABLE:
+            return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
+        raise ValueError("No available long context attention implementation")
     else:
         if attn_impl == "flash_attn_3":
-            attn_func = LongContextAttention(attn_type=AttnType.FA3)
-        elif attn_impl == "flash_attn_2":
-            attn_func = LongContextAttention(attn_type=AttnType.FA)
-        elif attn_impl == "sdpa":
-            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
-        elif attn_impl == "sage_attn":
-            attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
-        elif attn_impl == "sparge_attn":
+            if flash_attn3_compatible:
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            else:
+                raise RuntimeError(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
+                )
+        if attn_impl == "flash_attn_2":
+            return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sdpa":
+            return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sage_attn":
+            return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sparge_attn":
             attn_processor = SparseAttentionMeansim()
             # default args from spas_sage2_attn_meansim_cuda
             attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
             attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
             attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
             attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
-            attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
-        else:
-            raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
-    return attn_func(q, k, v, softmax_scale=scale)
+            return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
+                q, k, v, softmax_scale=scale
+            )
+        raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
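
For illustration, a minimal usage sketch of the head-dim guard in attention(). Assumptions: attention is imported from the patched module, q/k/v use a (batch, seq_len, heads, head_dim) layout, attn_impl is accepted as a keyword argument, and at least one fallback backend is installed; the 512-wide head is hypothetical and chosen only to exceed FA3_MAX_HEADDIM.

import torch

# Hypothetical tensors whose head_dim (last axis) exceeds FA3_MAX_HEADDIM = 256.
q = torch.randn(1, 128, 8, 512, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 512, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 512, device="cuda", dtype=torch.bfloat16)

# "auto" logs a warning and falls through to xformers / SDPA / flash_attn_2 / eager.
out = attention(q, k, v, attn_impl="auto")

# Explicitly requesting flash_attn_3 with an oversized head_dim now raises up front.
try:
    attention(q, k, v, attn_impl="flash_attn_3")
except RuntimeError as err:
    print(err)  # head_dim=512, but flash_attn_3 only supports head dimension at most 256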