
Commit 45ab89b

qzzz95 and akaitsuki-ii authored

defend flash attention3 failed (#126)

* defend flash attention3 failed
* rename
* add FA3_MAX_HEADDIM

Co-authored-by: zhuguoxuan.zgx <[email protected]>
1 parent c4ded2c commit 45ab89b

File tree

1 file changed: +53 −33 lines


diffsynth_engine/models/basic/attention.py

Lines changed: 53 additions & 33 deletions
@@ -14,6 +14,8 @@
     SPARGE_ATTN_AVAILABLE,
 )
 
+FA3_MAX_HEADDIM = 256
+
 logger = logging.get_logger(__name__)
 
 
@@ -130,31 +132,40 @@ def attention(
         "sage_attn",
         "sparge_attn",
     ]
+    flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
-            return flash_attn3(q, k, v, softmax_scale=scale)
-        elif XFORMERS_AVAILABLE:
+            if flash_attn3_compatible:
+                return flash_attn3(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
+        if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif SDPA_AVAILABLE:
+        if SDPA_AVAILABLE:
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif FLASH_ATTN_2_AVAILABLE:
+        if FLASH_ATTN_2_AVAILABLE:
             return flash_attn2(q, k, v, softmax_scale=scale)
-        else:
-            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
     else:
         if attn_impl == "eager":
             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "flash_attn_3":
+        if attn_impl == "flash_attn_3":
+            if not flash_attn3_compatible:
+                raise RuntimeError(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
+                )
             return flash_attn3(q, k, v, softmax_scale=scale)
-        elif attn_impl == "flash_attn_2":
+        if attn_impl == "flash_attn_2":
             return flash_attn2(q, k, v, softmax_scale=scale)
-        elif attn_impl == "xformers":
+        if attn_impl == "xformers":
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "sdpa":
+        if attn_impl == "sdpa":
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "sage_attn":
+        if attn_impl == "sage_attn":
             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        elif attn_impl == "sparge_attn":
+        if attn_impl == "sparge_attn":
             return sparge_attn(
                 q,
                 k,
@@ -166,8 +177,7 @@ def attention(
                 cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
                 pvthreshd=kwargs.get("sparge_pvthreshd", 50),
             )
-        else:
-            raise ValueError(f"Invalid attention implementation: {attn_impl}")
+        raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
 
 class Attention(nn.Module):
@@ -240,32 +250,42 @@ def long_context_attention(
         "sage_attn",
         "sparge_attn",
     ]
+    flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
-            attn_func = LongContextAttention(attn_type=AttnType.FA3)
-        elif SDPA_AVAILABLE:
-            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
-        elif FLASH_ATTN_2_AVAILABLE:
-            attn_func = LongContextAttention(attn_type=AttnType.FA)
-        else:
-            raise ValueError("No available long context attention implementation")
+            if flash_attn3_compatible:
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
+        if SDPA_AVAILABLE:
+            return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
+        if FLASH_ATTN_2_AVAILABLE:
+            return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
+        raise ValueError("No available long context attention implementation")
     else:
         if attn_impl == "flash_attn_3":
-            attn_func = LongContextAttention(attn_type=AttnType.FA3)
-        elif attn_impl == "flash_attn_2":
-            attn_func = LongContextAttention(attn_type=AttnType.FA)
-        elif attn_impl == "sdpa":
-            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
-        elif attn_impl == "sage_attn":
-            attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
-        elif attn_impl == "sparge_attn":
+            if flash_attn3_compatible:
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            else:
+                raise RuntimeError(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
+                )
+        if attn_impl == "flash_attn_2":
+            return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sdpa":
+            return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sage_attn":
+            return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sparge_attn":
             attn_processor = SparseAttentionMeansim()
             # default args from spas_sage2_attn_meansim_cuda
             attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
             attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
             attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
             attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
-            attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
-        else:
-            raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
-    return attn_func(q, k, v, softmax_scale=scale)
+            return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
+                q, k, v, softmax_scale=scale
+            )
+        raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
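
For reference, a minimal caller-side sketch of the behavior this commit introduces. The module path is taken from the diff above; the exact call signature, tensor layout (last dimension = head_dim), dtype, and device are illustrative assumptions, not facts from the repository.

import torch
from diffsynth_engine.models.basic.attention import attention

# head_dim = 512 exceeds FA3_MAX_HEADDIM (256), the new guard threshold
q = torch.randn(1, 77, 8, 512, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# "auto": flash_attn_3 is skipped with a warning and the call falls through
# to xformers / SDPA / flash_attn_2 / eager, whichever is available.
out = attention(q, k, v, attn_impl="auto")

# Explicit "flash_attn_3": fails fast instead of crashing inside the kernel.
try:
    attention(q, k, v, attn_impl="flash_attn_3")
except RuntimeError as err:
    print(err)  # head_dim=512, but flash_attn_3 only supports head dimension at most 256

Note that the early-return style introduced here (plain "if ... return" instead of elif/else chains) is what lets the "auto" path continue to the next available backend after the flash_attn_3 warning.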
