diff --git a/diffsynth_engine/configs/pipeline.py b/diffsynth_engine/configs/pipeline.py
index 0188e57..a997314 100644
--- a/diffsynth_engine/configs/pipeline.py
+++ b/diffsynth_engine/configs/pipeline.py
@@ -26,6 +26,7 @@ class AttnImpl(Enum):
     FA2 = "fa2"  # Flash Attention 2
     FA3 = "fa3"  # Flash Attention 3
     FA3_FP8 = "fa3_fp8"  # Flash Attention 3 with FP8
+    FA4 = "fa4"  # Flash Attention 4
     AITER = "aiter"  # Aiter Flash Attention
     AITER_FP8 = "aiter_fp8"  # Aiter Flash Attention with FP8
     XFORMERS = "xformers"  # XFormers
diff --git a/diffsynth_engine/models/basic/attention.py b/diffsynth_engine/models/basic/attention.py
index 5d98ae1..3d3d49c 100644
--- a/diffsynth_engine/models/basic/attention.py
+++ b/diffsynth_engine/models/basic/attention.py
@@ -6,6 +6,7 @@ from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.flag import (
+    FLASH_ATTN_4_AVAILABLE,
     FLASH_ATTN_3_AVAILABLE,
     FLASH_ATTN_2_AVAILABLE,
     XFORMERS_AVAILABLE,
@@ -21,7 +22,8 @@ logger = logging.get_logger(__name__)
-
+if FLASH_ATTN_4_AVAILABLE:
+    from flash_attn.cute.interface import flash_attn_func as flash_attn4
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -142,6 +144,7 @@ def attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "fa4",
         "aiter",
         "aiter_fp8",
         "xformers",
@@ -152,6 +155,22 @@ def attention(
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_4_AVAILABLE:
+            # FA4 also has the same max-head-256 limitation as FA3
+            if flash_attn3_compatible and attn_mask is None:
+                attn_out = flash_attn4(q, k, v, softmax_scale=scale)
+                if isinstance(attn_out, tuple):
+                    attn_out = attn_out[0]
+                return attn_out
+            else:
+                if not flash_attn3_compatible:
+                    logger.warning(
+                        f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                    )
+                else:
+                    logger.debug(
+                        "flash_attn_4 does not support attention mask, will use fallback attention implementation"
+                    )
         if FLASH_ATTN_3_AVAILABLE:
             if flash_attn3_compatible and attn_mask is None:
                 return flash_attn3(q, k, v, softmax_scale=scale)
@@ -213,6 +232,17 @@ def attention(
             v = v.to(dtype=DTYPE_FP8)
             out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
             return out.to(dtype=origin_dtype)
+    if attn_impl == "fa4":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_mask is not None:
+            raise RuntimeError("flash_attn_4 does not support attention mask")
+        attn_out = flash_attn4(q, k, v, softmax_scale=scale)
+        if isinstance(attn_out, tuple):
+            attn_out = attn_out[0]
+        return attn_out
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":
diff --git a/diffsynth_engine/utils/flag.py b/diffsynth_engine/utils/flag.py
index ff03949..84a42d5 100644
--- a/diffsynth_engine/utils/flag.py
+++ b/diffsynth_engine/utils/flag.py
@@ -7,6 +7,12 @@
 # 无损
+FLASH_ATTN_4_AVAILABLE = importlib.util.find_spec("flash_attn.cute.interface") is not None
+if FLASH_ATTN_4_AVAILABLE:
+    logger.info("Flash attention 4 is available")
+else:
+    logger.info("Flash attention 4 is not available")
+
 FLASH_ATTN_3_AVAILABLE = importlib.util.find_spec("flash_attn_interface") is not None
 if FLASH_ATTN_3_AVAILABLE:
     logger.info("Flash attention 3 is available")
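
A minimal usage sketch of the new backend selection follows. Assumptions, since this diff does not show the full signature of attention(): attention() is importable from diffsynth_engine.models.basic.attention and accepts q/k/v plus an attn_impl keyword as the hunk context suggests, and tensors use the (batch, seq_len, num_heads, head_dim) layout expected by flash-attention kernels.

import torch

from diffsynth_engine.models.basic.attention import attention
from diffsynth_engine.utils.flag import FLASH_ATTN_4_AVAILABLE

# bf16 tensors on CUDA; head_dim=128 stays within the FA3/FA4 head-dim limit
q = torch.randn(1, 1024, 16, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 1024, 16, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 1024, 16, 128, dtype=torch.bfloat16, device="cuda")

if FLASH_ATTN_4_AVAILABLE:
    # Explicit opt-in; per this patch it raises RuntimeError if head_dim
    # exceeds FA3_MAX_HEADDIM or an attention mask is supplied.
    out = attention(q, k, v, attn_impl="fa4")
else:
    # "auto" now prefers FA4 when available and compatible, otherwise it
    # falls through to the existing FA3/FA2/xformers/SDPA candidates.
    out = attention(q, k, v, attn_impl="auto")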