     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
     VIDEO_SPARSE_ATTN_AVAILABLE,
+    AITER_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8
 
@@ -93,6 +94,9 @@ def sparge_attn(
     )
     return out.transpose(1, 2)
 
+if AITER_AVAILABLE:
+    from aiter import flash_attn_func as aiter_flash_attn
+    from aiter import flash_attn_fp8_pertensor_func as aiter_flash_attn_fp8
 
 if VIDEO_SPARSE_ATTN_AVAILABLE:
     from diffsynth_engine.models.basic.video_sparse_attention import (
@@ -137,6 +141,8 @@ def attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "xformers",
         "sdpa",
         "sage",
@@ -157,6 +163,13 @@ def attention(
                 logger.debug(
                     "flash_attn_3 does not support attention mask, will use fallback attention implementation"
                 )
+        if AITER_AVAILABLE:
+            if flash_attn3_compatible:
+                return aiter_flash_attn(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -183,6 +196,22 @@ def attention(
         v = v.to(dtype=DTYPE_FP8)
         out = flash_attn3(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_mask is not None:
+            raise RuntimeError("aiter_flash_attn does not support attention mask")
+        if attn_impl == "aiter":
+            return aiter_flash_attn(q, k, v, softmax_scale=scale)
+        else:
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":
@@ -288,6 +317,8 @@ def long_context_attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "sdpa",
         "sage",
         "sparge",
@@ -303,6 +334,13 @@ def long_context_attention(
                 logger.warning(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
                 )
+        if AITER_AVAILABLE:
+            if flash_attn3_compatible:
+                return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
         if SDPA_AVAILABLE:
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
         if FLASH_ATTN_2_AVAILABLE:
@@ -323,6 +361,20 @@ def long_context_attention(
         v = v.to(dtype=DTYPE_FP8)
         out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_impl == "aiter":
+            return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+
+        origin_dtype = q.dtype
+        q = q.to(dtype=DTYPE_FP8)
+        k = k.to(dtype=DTYPE_FP8)
+        v = v.to(dtype=DTYPE_FP8)
+        out = LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+        return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     if attn_impl == "sdpa":
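
Usage note (not part of the diff): a minimal sketch of how the new backends would be selected through the public attention() entry point. It assumes a ROCm build where the aiter package imports cleanly, so AITER_AVAILABLE is True; the module path, the attention() keyword signature, and the (batch, seq_len, num_heads, head_dim) layout are assumptions inferred from the surrounding imports and flash-attn-style calls, not confirmed by this patch.

import torch
from diffsynth_engine.models.basic.attention import attention  # assumed module path

# Flash-attn-style layout assumed: (batch, seq_len, num_heads, head_dim).
# head_dim must satisfy the FA3_MAX_HEADDIM bound: the explicit aiter paths
# raise RuntimeError above it, while the "auto" path logs and falls back.
q = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attention(q, k, v, attn_impl="aiter")  # bf16 aiter kernel
# The fp8 variant casts q/k/v to DTYPE_FP8 internally and casts the output
# back, so the caller still receives bf16 here.
out_fp8 = attention(q, k, v, attn_impl="aiter_fp8")

Passing attn_mask with either aiter implementation raises RuntimeError, matching the flash_attn_3 limitation; long_context_attention routes the same two values through LongContextAttention with AttnType.AITER.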