 import torch
 import torch.nn as nn
 from einops import rearrange
-
+from typing import Optional
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.flag import (
+    FLASH_ATTN_3_AVAILABLE,
+    FLASH_ATTN_2_AVAILABLE,
+    XFORMERS_AVAILABLE,
+    SDPA_AVAILABLE,
+    SAGE_ATTN_AVAILABLE,
+    SPARGE_ATTN_AVAILABLE,
+)
+
+if FLASH_ATTN_3_AVAILABLE:
+    from flash_attn_interface import flash_attn_func as flash_attn3
+if FLASH_ATTN_2_AVAILABLE:
+    from flash_attn import flash_attn_func as flash_attn2
+if XFORMERS_AVAILABLE:
+    from xformers.ops import memory_efficient_attention as xformers_attn
+if SDPA_AVAILABLE:
+
+    def sdpa_attn(q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
+        return out.transpose(1, 2)
+
+
+if SAGE_ATTN_AVAILABLE:
+    from sageattention import sageattn
+
+    def sage_attn(q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = sageattn(q, k, v, attn_mask=attn_mask, sm_scale=scale)
+        return out.transpose(1, 2)
+
+
+if SPARGE_ATTN_AVAILABLE:
+    from spas_sage_attn import spas_sage2_attn_meansim_cuda
+
+    def sparge_attn(q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
+        return out.transpose(1, 2)
+
 
 logger = logging.get_logger(__name__)
 
 
+def eager_attn(query, key, value, attn_mask=None, scale=None):
+    # inputs use the same [B, L, N, C] layout as the other backends, so move the
+    # head dim in front of the sequence dim before computing attention scores
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+    scale = 1 / query.shape[-1] ** 0.5 if scale is None else scale
+    query = query * scale
+    attn = torch.matmul(query, key.transpose(-2, -1))
+    if attn_mask is not None:
+        attn = attn + attn_mask
+    attn = attn.softmax(-1)
+    return (attn @ value).transpose(1, 2)
+
+
+def attention(q, k, v, attn_mask=None, attn_impl: Optional[str] = None, scale: Optional[float] = None):
+    """
+    q: [B, Lq, Nq, C1]
+    k: [B, Lk, Nk, C1]
+    v: [B, Lk, Nk, C2]
+    """
+    assert attn_impl in [
+        None,
+        "auto",
+        "eager",
+        "flash_attn_2",
+        "flash_attn_3",
+        "xformers",
+        "sdpa",
+        "sage_attn",
+        "sparge_attn",
+    ]
+    if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_3_AVAILABLE:
+            return flash_attn3(q, k, v, softmax_scale=scale)
+        elif FLASH_ATTN_2_AVAILABLE:
+            return flash_attn2(q, k, v, softmax_scale=scale)
+        elif XFORMERS_AVAILABLE:
+            return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
+        elif SDPA_AVAILABLE:
+            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        else:
+            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+    else:
+        if attn_impl == "eager":
+            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "flash_attn_3":
+            return flash_attn3(q, k, v, softmax_scale=scale)
+        elif attn_impl == "flash_attn_2":
+            return flash_attn2(q, k, v, softmax_scale=scale)
+        elif attn_impl == "xformers":
+            return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
+        elif attn_impl == "sdpa":
+            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "sage_attn":
+            return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "sparge_attn":
+            return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        else:
+            raise ValueError(f"Invalid attention implementation: {attn_impl}")
+
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -18,7 +121,7 @@ def __init__(
         bias_kv=False,
         bias_out=False,
         scale=None,
-        attn_implementation: str = "sdpa",
+        attn_impl: Optional[str] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float16,
     ):
@@ -32,106 +135,20 @@ def __init__(
         self.to_k = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
         self.to_v = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
         self.to_out = nn.Linear(dim_inner, q_dim, bias=bias_out, device=device, dtype=dtype)
-
+        self.attn_impl = attn_impl
         self.scale = scale
-        self.attn_implementation = self._get_actual_attn_implementation(attn_implementation)
-
-    @staticmethod
-    def _get_actual_attn_implementation(attn_implementation):
-        supported_implementations = ("xformers", "sdpa", "eager")
-        if attn_implementation not in supported_implementations:
-            raise ValueError(
-                f"attn_implementation must be one of {supported_implementations}, but got '{attn_implementation}'"
-            )
-
-        actual_implementation = "eager" if attn_implementation == "eager" else ""
-        if attn_implementation == "xformers":
-            try:
-                from xformers.ops import memory_efficient_attention
-
-                actual_implementation = "xformers"
-            except ImportError:
-                pass
-        if not actual_implementation or attn_implementation == "sdpa":
-            use_mps = torch.backends.mps.is_available()
-            if hasattr(torch.nn.functional, "scaled_dot_product_attention") and not use_mps:
-                actual_implementation = "sdpa"
-
-        if actual_implementation != attn_implementation:
-            warning_msg = (
-                "xformers is not supported on this platform"
-                if attn_implementation == "xformers"
-                else "torch.nn.functional.scaled_dot_product_attention is not supported"
-            )
-            logger.warning(f"{warning_msg}, fallback to '{actual_implementation}' attention")
-        return actual_implementation
-
-    def sdpa_attn(self, hidden_states, encoder_hidden_states, attn_mask=None):
-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
-
-        q = rearrange(q, "b s (n d) -> b n s d", n=self.num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=self.num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=self.num_heads)
-
-        hidden_states = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
-        hidden_states = rearrange(hidden_states, "b n s d -> b s (n d)", n=self.num_heads)
-        hidden_states = hidden_states.to(q.dtype)
-        hidden_states = self.to_out(hidden_states)
-        return hidden_states
-
-    def xformers_attn(self, hidden_states, encoder_hidden_states, attn_mask=None):
-        import xformers.ops as xops
-
-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
-        q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
-        k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
-        v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
-
-        hidden_states = xops.memory_efficient_attention(q, k, v, attn_bias=attn_mask, scale=self.scale)
-        hidden_states = rearrange(hidden_states, "b s n d -> b s (n d)")
-        hidden_states = hidden_states.to(q.dtype)
-        hidden_states = self.to_out(hidden_states)
-        return hidden_states
-
-    def eager_attn(self, hidden_states, encoder_hidden_states, attn_mask=None):
-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
-        q = rearrange(q, "b s (n d) -> b n s d", n=self.num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=self.num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=self.num_heads)
-
-        hidden_states = self._eager_attn(q, k, v, attn_bias=attn_mask, scale=self.scale)
-        hidden_states = rearrange(hidden_states, "b n s d -> b s (n d)", n=self.num_heads)
-        hidden_states = hidden_states.to(q.dtype)
-        hidden_states = self.to_out(hidden_states)
-        return hidden_states
-
-    @staticmethod
-    def _eager_attn(query, key, value, attn_bias=None, scale=None):
-        scale = 1 / query.shape[-1] ** 0.5 if scale is None else scale
-        query = query * scale
-        attn = torch.matmul(query, key.transpose(-2, -1))
-        if attn_bias is not None:
-            attn = attn + attn_bias
-        attn = attn.softmax(-1)
-        return attn @ value
 
     def forward(
         self,
-        hidden_states,
-        encoder_hidden_states=None,
-        attn_mask=None,
+        x: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
     ):
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        if self.attn_implementation == "xformers":
-            return self.xformers_attn(hidden_states, encoder_hidden_states, attn_mask)
-        if self.attn_implementation == "sdpa":
-            return self.sdpa_attn(hidden_states, encoder_hidden_states, attn_mask)
-        return self.eager_attn(hidden_states, encoder_hidden_states, attn_mask)
+        if y is None:
+            y = x
+        q = rearrange(self.to_q(x), "b s (n d) -> b s n d", n=self.num_heads)
+        k = rearrange(self.to_k(y), "b s (n d) -> b s n d", n=self.num_heads)
+        v = rearrange(self.to_v(y), "b s (n d) -> b s n d", n=self.num_heads)
+        out = attention(q, k, v, attn_mask=attn_mask, attn_impl=self.attn_impl, scale=self.scale)
+        out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
+        return self.to_out(out)
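
A minimal usage sketch of the new functional entry point (not part of the commit). It assumes `attention` is imported from this module, a CUDA device with float16 tensors, and that SDPA is the fastest backend installed, so both calls below go through the `sdpa_attn` wrapper.

import torch

# q/k/v follow the documented [B, L, N, C] layout:
# batch 2, 128 query tokens, 77 key/value tokens, 8 heads of head_dim 64.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 77, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 77, 8, 64, device="cuda", dtype=torch.float16)

# attn_impl=None (or "auto") walks the availability flags in the order
# flash_attn_3 -> flash_attn_2 -> xformers -> sdpa -> eager;
# a specific backend can also be requested by name.
out = attention(q, k, v, attn_impl=None)    # -> [2, 128, 8, 64]
out = attention(q, k, v, attn_impl="sdpa")  # force the SDPA wrapper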