
Commit fbf26b7

Fix flash-attn 3 in the attention dispatch backend by calling _flash_attn_forward, replacing the original implementation; add a docstring example in the pipeline for it.
1 parent 71e8049 commit fbf26b7

File tree

2 files changed (+33, -6 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 26 additions & 4 deletions
@@ -79,9 +79,11 @@
 if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
+    flash_attn_3_forward = None

 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
@@ -621,22 +623,42 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
+    max_seqlen_q = q.shape[2]
+    max_seqlen_k = k.shape[2]
+
+    out, lse, *_ = flash_attn_3_forward(
         q=q,
         k=k,
         v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
+        k_new=None,
+        v_new=None,
         qv=qv,
+        out=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        cu_seqlens_k_new=None,
+        seqused_q=None,
+        seqused_k=None,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        page_table=None,
+        kv_batch_idx=None,
+        leftpad_k=None,
+        rotary_cos=None,
+        rotary_sin=None,
+        seqlens_rotary=None,
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
+        softmax_scale=softmax_scale,
+        causal=causal,
         window_size=window_size,
         attention_chunk=attention_chunk,
         softcap=softcap,
+        rotary_interleaved=True,
+        scheduler_metadata=None,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
-        deterministic=deterministic,
         sm_margin=sm_margin,
     )
     lse = lse.permute(0, 2, 1)
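
For context, the sketch below mirrors the guarded-import pattern this file relies on, written as a self-contained snippet with a try/except in place of the module's _CAN_USE_FLASH_ATTN_3 flag. It is an illustration only, not part of the commit; the fa3_available helper is hypothetical, and "native" is assumed to be the default (PyTorch SDPA) backend name.

# Illustration only (not part of this commit): a self-contained variant of the
# guarded import above, using try/except instead of the _CAN_USE_FLASH_ATTN_3 flag.
# fa3_available() is a hypothetical helper, not a diffusers API.
try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
except ImportError:
    flash_attn_3_func = None
    flash_attn_3_forward = None


def fa3_available() -> bool:
    """Return True only if the flash-attn 3 kernels were importable."""
    return flash_attn_3_forward is not None


if __name__ == "__main__":
    # Request "_flash_3" when available; otherwise keep the default SDPA backend
    # (assumed to be named "native" in diffusers).
    backend = "_flash_3" if fa3_available() else "native"
    print(f"attention backend to request: {backend}")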

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 7 additions & 2 deletions
@@ -39,9 +39,14 @@
 
         >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
+
+        >>> # Optionally, set the attention backend to flash-attn 2 or 3; the default is PyTorch SDPA.
+        >>> # (1) Use flash attention 2
+        >>> # pipe.transformer.set_attention_backend("flash")
+        >>> # (2) Use flash attention 3
+        >>> # pipe.transformer.set_attention_backend("_flash_3")
+
         >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化：一辆复古蒸汽小火车化身为巨大的拉链头，正拉开厚厚的冬日积雪，展露出一个生机盎然的春天。"
-        >>> # Depending on the variant being used, the pipeline call will slightly vary.
-        >>> # Refer to the pipeline documentation for more details.
         >>> image = pipe(
         ...     prompt,
         ...     height=1024,
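
For convenience, here is a hedged end-to-end sketch that assembles the docstring snippet above into a runnable script. Only the model id, dtype, device, backend names, prompt, and height come from the diff; the width, the .images[0] access, and the output filename follow the usual diffusers pipeline conventions and are assumptions, not part of this commit.

# Usage sketch (assumptions noted inline); selecting "_flash_3" requires the
# flash_attn_interface package and a supported GPU.
import torch
from diffusers import ZImagePipeline

pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Optional: switch from the default SDPA backend; use "flash" for flash-attn 2.
pipe.transformer.set_attention_backend("_flash_3")

prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化：一辆复古蒸汽小火车化身为巨大的拉链头，正拉开厚厚的冬日积雪，展露出一个生机盎然的春天。"
# width and the .images[0] access are assumed from standard diffusers conventions.
image = pipe(prompt, height=1024, width=1024).images[0]
image.save("z_image_turbo.png")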
