
Commit 28685dd

Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that."
This reverts commit fbf26b7.
1 parent: fbf26b7 · commit: 28685dd

File tree

2 files changed: +6, -33 lines
  src/diffusers/models/attention_dispatch.py
  src/diffusers/pipelines/z_image/pipeline_z_image.py

src/diffusers/models/attention_dispatch.py

Lines changed: 4 additions & 26 deletions
@@ -79,11 +79,9 @@
 if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
-    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
-    flash_attn_3_forward = None
 
 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
@@ -623,42 +621,22 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    max_seqlen_q = q.shape[2]
-    max_seqlen_k = k.shape[2]
-
-    out, lse, *_ = flash_attn_3_forward(
+    out, lse, *_ = flash_attn_3_func(
         q=q,
         k=k,
         v=v,
-        k_new=None,
-        v_new=None,
+        softmax_scale=softmax_scale,
+        causal=causal,
         qv=qv,
-        out=None,
-        cu_seqlens_q=None,
-        cu_seqlens_k=None,
-        cu_seqlens_k_new=None,
-        seqused_q=None,
-        seqused_k=None,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        page_table=None,
-        kv_batch_idx=None,
-        leftpad_k=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        seqlens_rotary=None,
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
-        softmax_scale=softmax_scale,
-        causal=causal,
         window_size=window_size,
         attention_chunk=attention_chunk,
         softcap=softcap,
-        rotary_interleaved=True,
-        scheduler_metadata=None,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
+        deterministic=deterministic,
         sm_margin=sm_margin,
     )
     lse = lse.permute(0, 2, 1)
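
With this revert, _wrapped_flash_attn_3 once again calls the public flash_attn_func entry point of flash_attn_interface instead of the internal _flash_attn_forward, which required manually derived max_seqlen_q/max_seqlen_k values and a long tail of None placeholders. Below is a minimal sketch of the restored call shape, assuming flash-attn 3 is installed and that q, k, v follow flash-attn's usual (batch, seq_len, num_heads, head_dim) layout; the keyword names are taken from the diff above, and the helper itself is hypothetical, not part of diffusers.

```python
import torch
from typing import Optional, Tuple

# Guarded import, mirroring the pattern restored in attention_dispatch.py.
try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    flash_attn_3_func = None


def fa3_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: run a plain flash-attn 3 forward and return (out, lse)."""
    if flash_attn_3_func is None:
        raise RuntimeError("flash-attn 3 (flash_attn_interface) is not installed")
    # flash_attn_func returns the attention output and the log-sum-exp;
    # any extra return values are ignored, just as in the wrapper above.
    out, lse, *_ = flash_attn_3_func(
        q=q,
        k=k,
        v=v,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    return out, lse
```

Unlike the reverted _flash_attn_forward path, the caller does not need to compute sequence lengths; the public API derives them from the tensor shapes.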

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 2 additions & 7 deletions
@@ -39,14 +39,9 @@
 
         >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
-
-        >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
-        >>> # (1) Use flash attention 2
-        >>> # pipe.transformer.set_attention_backend("flash")
-        >>> # (2) Use flash attention 3
-        >>> # pipe.transformer.set_attention_backend("_flash_3")
-
         >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
+        >>> # Depending on the variant being used, the pipeline call will slightly vary.
+        >>> # Refer to the pipeline documentation for more details.
         >>> image = pipe(
         ...     prompt,
         ...     height=1024,
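
The docstring lines removed here documented how to switch the transformer's attention backend away from the PyTorch SDPA default; the capability itself is unchanged by the revert, only the example text is gone. A short usage sketch reconstructed from those removed lines: the backend names "flash" (flash-attn 2) and "_flash_3" (flash-attn 3) come verbatim from the deleted comments, while the import path is assumed from the usual diffusers convention.

```python
import torch
from diffusers import ZImagePipeline  # import path assumed; see the pipeline docs

pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Optionally switch the attention backend; the default is PyTorch SDPA.
pipe.transformer.set_attention_backend("flash")       # flash-attn 2
# pipe.transformer.set_attention_backend("_flash_3")  # flash-attn 3
```

The generation call then proceeds exactly as in the docstring example above (prompt, height, and so on). Either backend requires the corresponding flash-attn package to be installed.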
