Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffsynth_engine/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class AttnImpl(Enum):
FA2 = "fa2" # Flash Attention 2
FA3 = "fa3" # Flash Attention 3
FA3_FP8 = "fa3_fp8" # Flash Attention 3 with FP8
FA4 = "fa4" # Flash Attention 4
AITER = "aiter" # Aiter Flash Attention
AITER_FP8 = "aiter_fp8" # Aiter Flash Attention with FP8
XFORMERS = "xformers" # XFormers
Expand Down
32 changes: 31 additions & 1 deletion diffsynth_engine/models/basic/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from diffsynth_engine.utils import logging
from diffsynth_engine.utils.flag import (
FLASH_ATTN_4_AVAILABLE,
FLASH_ATTN_3_AVAILABLE,
FLASH_ATTN_2_AVAILABLE,
XFORMERS_AVAILABLE,
Expand All @@ -21,7 +22,8 @@

logger = logging.get_logger(__name__)


if FLASH_ATTN_4_AVAILABLE:
from flash_attn.cute.interface import flash_attn_func as flash_attn4
if FLASH_ATTN_3_AVAILABLE:
from flash_attn_interface import flash_attn_func as flash_attn3
if FLASH_ATTN_2_AVAILABLE:
Expand Down Expand Up @@ -142,6 +144,7 @@ def attention(
"fa2",
"fa3",
"fa3_fp8",
"fa4",
"aiter",
"aiter_fp8",
"xformers",
Expand All @@ -152,6 +155,22 @@ def attention(
]
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
if attn_impl is None or attn_impl == "auto":
if FLASH_ATTN_4_AVAILABLE:
# FA4 also has the same max-head-256 limitation as FA3
if flash_attn3_compatible and attn_mask is None:
attn_out = flash_attn4(q, k, v, softmax_scale=scale)
if isinstance(attn_out, tuple):
attn_out = attn_out[0]
return attn_out
else:
if not flash_attn3_compatible:
logger.warning(
f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
)
else:
logger.debug(
"flash_attn_4 does not support attention mask, will use fallback attention implementation"
)
if FLASH_ATTN_3_AVAILABLE:
if flash_attn3_compatible and attn_mask is None:
return flash_attn3(q, k, v, softmax_scale=scale)
Expand Down Expand Up @@ -213,6 +232,17 @@ def attention(
v = v.to(dtype=DTYPE_FP8)
out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
return out.to(dtype=origin_dtype)
if attn_impl == "fa4":
if not flash_attn3_compatible:
raise RuntimeError(
f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}"
)
if attn_mask is not None:
raise RuntimeError("flash_attn_4 does not support attention mask")
attn_out = flash_attn4(q, k, v, softmax_scale=scale)
if isinstance(attn_out, tuple):
attn_out = attn_out[0]
return attn_out
Comment on lines +242 to +245
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block — calling flash_attn4 and unwrapping its tuple output — is identical to the one at lines 161-164. Duplicated code risks a bug fix or change being applied in one place but not the other; consider extracting it into a small helper function to avoid the repetition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm following the existing code pattern and don't want to introduce unnecessary refactoring.

if attn_impl == "fa2":
return flash_attn2(q, k, v, softmax_scale=scale)
if attn_impl == "xformers":
Expand Down
6 changes: 6 additions & 0 deletions diffsynth_engine/utils/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@


# 无损
FLASH_ATTN_4_AVAILABLE = importlib.util.find_spec("flash_attn.cute.interface") is not None
if FLASH_ATTN_4_AVAILABLE:
logger.info("Flash attention 4 is available")
else:
logger.info("Flash attention 4 is not available")
Comment on lines +11 to +14
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This if/else block for logging availability repeats a pattern already common in this file. It can be collapsed into a single line using an f-string with a conditional expression, reducing duplication.

logger.info(f"Flash attention 4 {'is' if FLASH_ATTN_4_AVAILABLE else 'is not'} available")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm following the existing code pattern and don't want to introduce unnecessary refactoring.


FLASH_ATTN_3_AVAILABLE = importlib.util.find_spec("flash_attn_interface") is not None
if FLASH_ATTN_3_AVAILABLE:
logger.info("Flash attention 3 is available")
Expand Down