30 | 30 | raise e |
31 | 31 | exit(-1) |
32 | 32 |
| 33 | +SAGE_ATTENTION3_IS_AVAILABLE = False |
| 34 | +try: |
| 35 | + from sageattn3 import sageattn3_blackwell |
| 36 | + SAGE_ATTENTION3_IS_AVAILABLE = True |
| 37 | +except ImportError: |
| 38 | + pass |
| 39 | + |
33 | 40 | FLASH_ATTENTION_IS_AVAILABLE = False |
34 | 41 | try: |
35 | 42 | from flash_attn import flash_attn_func |
@@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= |
563 | 570 | out = out.reshape(b, -1, heads * dim_head) |
564 | 571 | return out |
565 | 572 |
| 573 | +@wrap_attn |
| 574 | +def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): |
| 575 | + exception_fallback = False |
| 576 | + if (q.device.type != "cuda" or |
| 577 | + q.dtype not in (torch.float16, torch.bfloat16) or |
| 578 | + mask is not None): |
| 579 | + return attention_pytorch( |
| 580 | + q, k, v, heads, |
| 581 | + mask=mask, |
| 582 | + attn_precision=attn_precision, |
| 583 | + skip_reshape=skip_reshape, |
| 584 | + skip_output_reshape=skip_output_reshape, |
| 585 | + **kwargs |
| 586 | + ) |
| 587 | + |
| 588 | + if skip_reshape: |
| 589 | + B, H, L, D = q.shape |
| 590 | + if H != heads: |
| 591 | + return attention_pytorch( |
| 592 | + q, k, v, heads, |
| 593 | + mask=mask, |
| 594 | + attn_precision=attn_precision, |
| 595 | + skip_reshape=True, |
| 596 | + skip_output_reshape=skip_output_reshape, |
| 597 | + **kwargs |
| 598 | + ) |
| 599 | + q_s, k_s, v_s = q, k, v |
| 600 | + N = q.shape[2] |
| 601 | + dim_head = D |
| 602 | + else: |
| 603 | + B, N, inner_dim = q.shape |
| 604 | + if inner_dim % heads != 0: |
| 605 | + return attention_pytorch( |
| 606 | + q, k, v, heads, |
| 607 | + mask=mask, |
| 608 | + attn_precision=attn_precision, |
| 609 | + skip_reshape=False, |
| 610 | + skip_output_reshape=skip_output_reshape, |
| 611 | + **kwargs |
| 612 | + ) |
| 613 | + dim_head = inner_dim // heads |
| 614 | + |
| 615 | + if dim_head >= 256 or N <= 1024: |
| 616 | + return attention_pytorch( |
| 617 | + q, k, v, heads, |
| 618 | + mask=mask, |
| 619 | + attn_precision=attn_precision, |
| 620 | + skip_reshape=skip_reshape, |
| 621 | + skip_output_reshape=skip_output_reshape, |
| 622 | + **kwargs |
| 623 | + ) |
| 624 | + |
| 625 | + if not skip_reshape: |
| 626 | + q_s, k_s, v_s = map( |
| 627 | + lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(), |
| 628 | + (q, k, v), |
| 629 | + ) |
| 630 | + B, H, L, D = q_s.shape |
| 631 | + |
| 632 | + try: |
| 633 | + out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False) |
| 634 | + except Exception as e: |
| 635 | + exception_fallback = True |
| 636 | + logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e) |
| 637 | + |
| 638 | + if exception_fallback: |
| 639 | + if not skip_reshape: |
| 640 | + del q_s, k_s, v_s |
| 641 | + return attention_pytorch( |
| 642 | + q, k, v, heads, |
| 643 | + mask=mask, |
| 644 | + attn_precision=attn_precision, |
| 645 | +                skip_reshape=skip_reshape,  # preserve the caller's tensor layout in the fallback |
| 646 | + skip_output_reshape=skip_output_reshape, |
| 647 | + **kwargs |
| 648 | + ) |
| 649 | + |
| 650 | + if skip_reshape: |
| 651 | + if not skip_output_reshape: |
| 652 | + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) |
| 653 | + else: |
| 654 | + if skip_output_reshape: |
| 655 | + pass |
| 656 | + else: |
| 657 | + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) |
| 658 | + |
| 659 | + return out |
566 | 660 |
567 | 661 | try: |
568 | 662 | @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) |
@@ -650,6 +744,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape |
650 | 744 | # register core-supported attention functions |
651 | 745 | if SAGE_ATTENTION_IS_AVAILABLE: |
652 | 746 | register_attention_function("sage", attention_sage) |
| 747 | +if SAGE_ATTENTION3_IS_AVAILABLE: |
| 748 | + register_attention_function("sage3", attention3_sage) |
653 | 749 | if FLASH_ATTENTION_IS_AVAILABLE: |
654 | 750 | register_attention_function("flash", attention_flash) |
655 | 751 | if model_management.xformers_enabled(): |
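
A minimal usage sketch of the newly registered "sage3" backend, assuming the sageattn3 package is installed, a supported Blackwell-class CUDA GPU is present, and that the wrap_attn decorator leaves the plain call signature shown above intact. Shapes follow the (B, N, heads * dim_head) layout that attention3_sage expects when skip_reshape is False.

    import torch

    # N > 1024 and dim_head < 256, so the SageAttention3 branch is actually taken;
    # any unsupported case (CPU tensors, fp32, masks, odd shapes) falls back to
    # attention_pytorch inside attention3_sage.
    B, N, heads, dim_head = 1, 4096, 8, 64
    q = torch.randn(B, N, heads * dim_head, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = attention3_sage(q, k, v, heads)
    assert out.shape == (B, N, heads * dim_head)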