@@ -822,19 +822,27 @@ def __init__(
         self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
         self.norm_out = get_normalization(norm_type, num_features=out_channels)
 
-        self.processor = SanaMultiscaleLinearAttnProcessor2_0()
-        self.processor_quadratic = SanaMultiscaleQuadraticAttnProcessor2_0()
+        self.processor = SanaMultiscaleAttnProcessor2_0()
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        height, width = hidden_states.shape[-2:]
+    def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)  # append a row of ones so the matmuls below also produce the normalizer
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)
 
-        if height * width > self.attention_head_dim:
-            hidden_states = self.processor(self, hidden_states)
-        else:
-            hidden_states = self.processor_quadratic(self, hidden_states)
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+        return hidden_states
 
+    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+        hidden_states = torch.matmul(value, scores)
         return hidden_states
 
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.processor(self, hidden_states)
+
 
 class AttnProcessor:
     r"""
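
Note on the new apply_linear_attention helper above: the F.pad call appends a row of ones to value, so the same two matmuls produce both the attention numerator and its normalizer in a single pass. The snippet below is a minimal standalone sketch, not part of the diff; the shapes, the ReLU stand-in for attn.nonlinearity, and the eps value are assumptions.

import torch
import torch.nn.functional as F

eps = 1e-15
# (batch, heads, head_dim, seq) layout, matching the reshape used by the processor
query = torch.relu(torch.randn(2, 4, 8, 32))
key = torch.relu(torch.randn(2, 4, 8, 32))
value = torch.randn(2, 4, 8, 32)

# Padded-value trick: the appended ones row rides along through both matmuls
# and ends up holding the per-query normalizer.
padded = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
out = torch.matmul(torch.matmul(padded, key.transpose(-1, -2)), query)
out = out[:, :, :-1] / (out[:, :, -1:] + eps)

# Reference computed with an explicit ones row instead of the padding trick.
numerator = value @ key.transpose(-1, -2) @ query
normalizer = torch.ones_like(value[:, :, :1]) @ key.transpose(-1, -2) @ query
assert torch.allclose(out, numerator / (normalizer + eps), atol=1e-5)
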
@@ -5089,65 +5097,18 @@ def __call__(
         return hidden_states
 
 
-class SanaMultiscaleLinearAttnProcessor2_0:
+class SanaMultiscaleAttnProcessor2_0:
     r"""
-    Processor for implementing multiscale linear attention.
+    Processor for implementing multiscale linear and quadratic attention.
50955103 """
50965104
50975105 def __call__ (self , attn : SanaMultiscaleLinearAttention , hidden_states : torch .Tensor ) -> torch .Tensor :
5098- residual = hidden_states
5099-
5100- batch_size , _ , height , width = hidden_states .shape
5101- original_dtype = hidden_states .dtype
5102-
5103- hidden_states = hidden_states .movedim (1 , - 1 )
5104- query = attn .to_q (hidden_states )
5105- key = attn .to_k (hidden_states )
5106- value = attn .to_v (hidden_states )
5107- hidden_states = torch .cat ([query , key , value ], dim = 3 )
5108- hidden_states = hidden_states .movedim (- 1 , 1 )
5109-
5110- multiscale_hidden_states = [hidden_states ]
5111- for block in attn .to_qkv_multiscale :
5112- multiscale_hidden_states .append (block (hidden_states ))
5113-
5114- hidden_states = torch .cat (multiscale_hidden_states , dim = 1 )
5115-
5116- hidden_states = hidden_states .to (dtype = torch .float32 )
5117- hidden_states = hidden_states .reshape (batch_size , - 1 , 3 * attn .attention_head_dim , height * width )
5118-
5119- query , key , value = hidden_states .chunk (3 , dim = 2 )
5120- query = attn .nonlinearity (query )
5121- key = attn .nonlinearity (key )
5122- value = F .pad (value , (0 , 0 , 0 , 1 ), mode = "constant" , value = 1 )
5123-
5124- scores = torch .matmul (value , key .transpose (- 1 , - 2 ))
5125- hidden_states = torch .matmul (scores , query )
5126-
5127- hidden_states = hidden_states .to (dtype = torch .float32 )
5128- hidden_states = hidden_states [:, :, :- 1 ] / (hidden_states [:, :, - 1 :] + attn .eps )
5129- hidden_states = hidden_states .to (dtype = original_dtype )
5130-
5131- hidden_states = torch .reshape (hidden_states , (batch_size , - 1 , height , width ))
5132- hidden_states = attn .to_out (hidden_states .movedim (1 , - 1 )).movedim (- 1 , 1 )
5133-
5134- if attn .norm_type == "rms_norm" :
5135- hidden_states = attn .norm_out (hidden_states .movedim (1 , - 1 )).movedim (- 1 , 1 )
5106+ height , width = hidden_states .shape [- 2 :]
5107+ if height * width > attn .attention_head_dim :
5108+ use_linear_attention = True
51365109 else :
5137- hidden_states = attn . norm_out ( hidden_states )
5110+ use_linear_attention = False
51385111
5139- if attn .residual_connection :
5140- hidden_states = hidden_states + residual
5141-
5142- return hidden_states
5143-
5144-
5145- class SanaMultiscaleQuadraticAttnProcessor2_0 :
5146- r"""
5147- Processor for implementing multiscale quadratic attention.
5148- """
5149-
5150- def __call__ (self , attn : SanaMultiscaleLinearAttention , hidden_states : torch .Tensor ) -> torch .Tensor :
51515112 residual = hidden_states
51525113
51535114 batch_size , _ , height , width = list (hidden_states .size ())
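
The dispatch added in __call__ picks linear attention once the token count (height * width) exceeds attention_head_dim, which is where the matmul costs of the two branches cross. A rough sketch of that reasoning follows; the helper name and example numbers are illustrative, not from the diff.

def cheaper_branch(height: int, width: int, attention_head_dim: int) -> str:
    # Per head: the linear branch does (d+1, N) @ (N, d) then (d+1, d) @ (d, N)  -> O(N * d^2),
    # while the quadratic branch does (N, d) @ (d, N) then (d, N) @ (N, N)       -> O(N^2 * d).
    # The costs cross at N == d, matching the `height * width > attn.attention_head_dim` test.
    seq = height * width
    return "linear" if seq > attention_head_dim else "quadratic"

assert cheaper_branch(32, 32, attention_head_dim=8) == "linear"    # 1024 tokens, 8-dim heads
assert cheaper_branch(2, 2, attention_head_dim=8) == "quadratic"   # 4 tokens, 8-dim heads
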
@@ -5166,17 +5127,21 @@ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Ten
 
         hidden_states = torch.cat(multi_scale_qkv, dim=1)
 
+        if use_linear_attention:
+            # For linear attention, upcast hidden_states to float32
+            hidden_states = hidden_states.to(dtype=torch.float32)
+
         hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
 
         query, key, value = hidden_states.chunk(3, dim=2)
         query = attn.nonlinearity(query)
         key = attn.nonlinearity(key)
 
-        scores = torch.matmul(key.transpose(-1, -2), query)
-        scores = scores.to(dtype=torch.float32)
-        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + attn.eps)
-        scores = scores.to(dtype=original_dtype)
-        hidden_states = torch.matmul(value, scores)
+        if use_linear_attention:
+            hidden_states = attn.apply_linear_attention(query, key, value)
+            hidden_states = hidden_states.to(dtype=original_dtype)
+        else:
+            hidden_states = attn.apply_quadratic_attention(query, key, value)
 
         hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
         hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
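
Since both branches now feed the same query/key/value into either apply_linear_attention or apply_quadratic_attention, the refactor only changes how the result is computed, not what it is. The following standalone sanity check of that equivalence is a sketch under assumed shapes and eps, not diffusers code.

import torch
import torch.nn.functional as F

eps = 1e-15
q = torch.relu(torch.randn(1, 2, 8, 16, dtype=torch.float64))
k = torch.relu(torch.randn(1, 2, 8, 16, dtype=torch.float64))
v = torch.randn(1, 2, 8, 16, dtype=torch.float64)

# Linear branch: padded value, two small matmuls, divide by the appended normalizer row.
v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
linear = torch.matmul(torch.matmul(v_pad, k.transpose(-1, -2)), q)
linear = linear[:, :, :-1] / (linear[:, :, -1:] + eps)

# Quadratic branch: explicit (seq x seq) scores normalized over the key axis.
scores = torch.matmul(k.transpose(-1, -2), q)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + eps)
quadratic = torch.matmul(v, scores)

assert torch.allclose(linear, quadratic, atol=1e-6)
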