
Commit 0357ed7

Add support for sage attention 3 in comfyui, enable via new cli arg (#11026)
* Add support for sage attention 3 in comfyui, enable via new cli arg --use-sage-attiention3
* Fix some bugs found in PR review. The N dimension at which Sage Attention 3 takes effect is reduced to 1024 (although the improvement is not significant at this scale).
* Remove the Sage Attention3 switch, but retain the attention function registration.
* Fix a ruff check issue in attention.py
1 parent f59f71c commit 0357ed7
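
The gating described in the bullets above is easiest to see in isolation. The sketch below is a hypothetical helper (sage3_eligible is not part of this commit) that restates the preconditions under which the new attention3_sage function actually calls the SageAttention 3 kernel; in every other case it falls back to attention_pytorch. It assumes the unreshaped (batch, tokens, heads * dim_head) input layout.

    import torch

    def sage3_eligible(q: torch.Tensor, mask, heads: int) -> bool:
        # Hypothetical helper, not part of the commit: it mirrors the checks in
        # attention3_sage for the skip_reshape=False layout.
        B, N, inner_dim = q.shape           # (batch, tokens, heads * dim_head)
        if q.device.type != "cuda":         # the kernel is CUDA-only
            return False
        if q.dtype not in (torch.float16, torch.bfloat16):
            return False
        if mask is not None:                # attention masks are not supported
            return False
        if inner_dim % heads != 0:
            return False
        dim_head = inner_dim // heads
        # Head dims of 256+ and short sequences (N <= 1024) stay on the PyTorch path;
        # per the commit message, gains below ~1024 tokens are not significant anyway.
        return dim_head < 256 and N > 1024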


comfy/ldm/modules/attention.py

Lines changed: 96 additions & 0 deletions
@@ -30,6 +30,13 @@
             raise e
         exit(-1)
 
+SAGE_ATTENTION3_IS_AVAILABLE = False
+try:
+    from sageattn3 import sageattn3_blackwell
+    SAGE_ATTENTION3_IS_AVAILABLE = True
+except ImportError:
+    pass
+
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
     from flash_attn import flash_attn_func
@@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         out = out.reshape(b, -1, heads * dim_head)
     return out
 
+@wrap_attn
+def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
+    if (q.device.type != "cuda" or
+            q.dtype not in (torch.float16, torch.bfloat16) or
+            mask is not None):
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        B, H, L, D = q.shape
+        if H != heads:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=True,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        q_s, k_s, v_s = q, k, v
+        N = q.shape[2]
+        dim_head = D
+    else:
+        B, N, inner_dim = q.shape
+        if inner_dim % heads != 0:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=False,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        dim_head = inner_dim // heads
+
+    if dim_head >= 256 or N <= 1024:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if not skip_reshape:
+        q_s, k_s, v_s = map(
+            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
+            (q, k, v),
+        )
+        B, H, L, D = q_s.shape
+
+    try:
+        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
+    except Exception as e:
+        exception_fallback = True
+        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+
+    if exception_fallback:
+        if not skip_reshape:
+            del q_s, k_s, v_s
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=False,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        if not skip_output_reshape:
+            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+    else:
+        if skip_output_reshape:
+            pass
+        else:
+            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+
+    return out
+
 
 try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
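
For reference, a minimal smoke test of the function added above might look like the following. This is not part of the commit; it assumes a Blackwell GPU with the sageattn3 package installed, and that the wrap_attn decorator forwards these arguments unchanged.

    # Hypothetical smoke test (not part of the commit).
    import torch
    from comfy.ldm.modules.attention import attention3_sage

    B, N, heads, dim_head = 1, 2048, 8, 64   # N > 1024 and dim_head < 256, so the sage3 path engages
    q = torch.randn(B, N, heads * dim_head, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = attention3_sage(q, k, v, heads)    # any failed precondition falls back to attention_pytorch
    print(out.shape)                         # expected: torch.Size([1, 2048, 512])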
@@ -650,6 +744,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if SAGE_ATTENTION3_IS_AVAILABLE:
+    register_attention_function("sage3", attention3_sage)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
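
The registration guard in the last hunk follows the same pattern as the existing backends: a name is only registered when the optional import succeeded, which is how the "sage3" backend stays available even though the dedicated CLI switch was removed. A simplified stand-in for that mechanism (not ComfyUI's actual register_attention_function) is sketched below.

    # Simplified illustration of conditional backend registration; a stand-in,
    # not ComfyUI's real implementation.
    from typing import Callable, Dict

    ATTENTION_BACKENDS: Dict[str, Callable] = {}

    def register_attention_function(name: str, func: Callable) -> None:
        ATTENTION_BACKENDS[name] = func

    SAGE_ATTENTION3_IS_AVAILABLE = False
    try:
        from sageattn3 import sageattn3_blackwell  # optional-dependency probe
        SAGE_ATTENTION3_IS_AVAILABLE = True
    except ImportError:
        pass

    def attention3_sage(q, k, v, heads, **kwargs):
        ...  # the function added in this commit would go here

    if SAGE_ATTENTION3_IS_AVAILABLE:
        # "sage3" becomes selectable only when the import above succeeded.
        register_attention_function("sage3", attention3_sage)

    print(sorted(ATTENTION_BACKENDS))  # ['sage3'] if sageattn3 is installed, [] otherwise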
