
Commit 6b2825e

Enable aiter attention for ROCm (#198)

* Added aiter support for the attention backend, rebased to main
* Fixed a bug in attention.py
* Adjusted formatting
* Added aiter attention to long_context_attention for attn_impl=auto
* Modified the name of the aiter attention impl
* Modified the name for aiter attention
* Modified naming in the config pipeline
1 parent 1f969ae commit 6b2825e
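For a quick picture of what this commit enables, here is a hypothetical call-site sketch. The full signature of attention() is not part of this diff, so the keyword names (attn_mask, scale, attn_impl) and the (batch, seq_len, num_heads, head_dim) layout are assumptions inferred from the hunks in attention.py below; on a ROCm machine the GPU device still shows up as "cuda".

    import torch

    from diffsynth_engine.models.basic.attention import attention

    # Hypothetical shapes; head_dim=128 stays within the FA3_MAX_HEADDIM limit
    # that the aiter path is gated on.
    q = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = attention(q, k, v, attn_impl="aiter")          # aiter flash attention
    out_fp8 = attention(q, k, v, attn_impl="aiter_fp8")  # per-tensor FP8 variant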

File tree (3 files changed: +59 -0 lines)

* diffsynth_engine/configs/pipeline.py
* diffsynth_engine/models/basic/attention.py
* diffsynth_engine/utils/flag.py

diffsynth_engine/configs/pipeline.py (2 additions, 0 deletions)

@@ -26,6 +26,8 @@ class AttnImpl(Enum):
     FA2 = "fa2"  # Flash Attention 2
     FA3 = "fa3"  # Flash Attention 3
     FA3_FP8 = "fa3_fp8"  # Flash Attention 3 with FP8
+    AITER = "aiter"  # Aiter Flash Attention
+    AITER_FP8 = "aiter_fp8"  # Aiter Flash Attention with FP8
     XFORMERS = "xformers"  # XFormers
     SDPA = "sdpa"  # Scaled Dot Product Attention
     SAGE = "sage"  # Sage Attention
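The two new enum members carry the same string values ("aiter", "aiter_fp8") that attention() and long_context_attention() dispatch on below, so a config-level AttnImpl.AITER and a raw attn_impl="aiter" string should name the same backend. A minimal sanity check, assuming nothing beyond this hunk:

    from diffsynth_engine.configs.pipeline import AttnImpl

    # The enum values line up with the attn_impl strings checked in attention.py.
    assert AttnImpl.AITER.value == "aiter"
    assert AttnImpl.AITER_FP8.value == "aiter_fp8"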

diffsynth_engine/models/basic/attention.py (52 additions, 0 deletions)

@@ -13,6 +13,7 @@
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
     VIDEO_SPARSE_ATTN_AVAILABLE,
+    AITER_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8

@@ -93,6 +94,9 @@ def sparge_attn(
     )
     return out.transpose(1, 2)

+if AITER_AVAILABLE:
+    from aiter import flash_attn_func as aiter_flash_attn
+    from aiter import flash_attn_fp8_pertensor_func as aiter_flash_attn_fp8

 if VIDEO_SPARSE_ATTN_AVAILABLE:
     from diffsynth_engine.models.basic.video_sparse_attention import (

@@ -137,6 +141,8 @@ def attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "xformers",
         "sdpa",
         "sage",

@@ -157,6 +163,13 @@ def attention(
                 logger.debug(
                     "flash_attn_3 does not support attention mask, will use fallback attention implementation"
                 )
+        if AITER_AVAILABLE:
+            if flash_attn3_compatible:
+                return aiter_flash_attn(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:

@@ -183,6 +196,22 @@ def attention(
         v = v.to(dtype=DTYPE_FP8)
         out = flash_attn3(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_mask is not None:
+            raise RuntimeError("aiter_flash_attn does not support attention mask")
+        if attn_impl == "aiter":
+            return aiter_flash_attn(q, k, v, softmax_scale=scale)
+        else:
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":

@@ -288,6 +317,8 @@ def long_context_attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "sdpa",
         "sage",
         "sparge",

@@ -303,6 +334,13 @@ def long_context_attention(
                 logger.warning(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
                 )
+        if AITER_AVAILABLE:
+            if flash_attn3_compatible:
+                return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
         if SDPA_AVAILABLE:
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
         if FLASH_ATTN_2_AVAILABLE:

@@ -323,6 +361,20 @@ def long_context_attention(
         v = v.to(dtype=DTYPE_FP8)
         out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_impl == "aiter":
+            return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+
+        origin_dtype = q.dtype
+        q = q.to(dtype=DTYPE_FP8)
+        k = k.to(dtype=DTYPE_FP8)
+        v = v.to(dtype=DTYPE_FP8)
+        out = LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+        return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     if attn_impl == "sdpa":
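The aiter_fp8 branch follows the same per-tensor FP8 recipe as the existing fa3_fp8 path: remember the caller's dtype, cast q/k/v to DTYPE_FP8, run the FP8 kernel, and cast the result back. A standalone sketch of that pattern, assuming a ROCm build with aiter installed (aiter_flash_attn_fp8 is the alias this file defines for aiter.flash_attn_fp8_pertensor_func):

    import torch

    from aiter import flash_attn_fp8_pertensor_func as aiter_flash_attn_fp8
    from diffsynth_engine.utils.platform import DTYPE_FP8


    def aiter_attention_fp8(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale=None) -> torch.Tensor:
        # Keep the original dtype so the output can be handed back unchanged.
        origin_dtype = q.dtype
        q = q.to(dtype=DTYPE_FP8)
        k = k.to(dtype=DTYPE_FP8)
        v = v.to(dtype=DTYPE_FP8)
        out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
        return out.to(dtype=origin_dtype)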

diffsynth_engine/utils/flag.py (5 additions, 0 deletions)

@@ -31,6 +31,11 @@
 else:
     logger.info("Torch SDPA is not available")

+AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None
+if AITER_AVAILABLE:
+    logger.info("Aiter is available")
+else:
+    logger.info("Aiter is not available")

 # 有损 (lossy)
 SAGE_ATTN_AVAILABLE = importlib.util.find_spec("sageattention") is not None
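The new flag mirrors the other optional-backend probes in flag.py: importlib.util.find_spec only checks that the package can be found, so nothing is imported when aiter is absent. Downstream code, or anyone debugging a ROCm install, can read the same flag:

    from diffsynth_engine.utils.flag import AITER_AVAILABLE

    # If this prints False, attn_impl="auto" keeps falling back to the other
    # implementations (xformers, SDPA, FA2, ...) as before this commit.
    print("aiter attention available:", AITER_AVAILABLE)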
