@@ -752,7 +752,7 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse


-class MultiscaleAttentionProjection(nn.Module):
+class SanaMultiscaleAttentionProjection(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -761,25 +761,24 @@ def __init__(
     ) -> None:
         super().__init__()

+        channels = 3 * in_channels
         self.proj_in = nn.Conv2d(
-            3 * in_channels,
-            3 * in_channels,
+            channels,
+            channels,
             kernel_size,
             padding=kernel_size // 2,
             groups=3 * in_channels,
             bias=False,
         )
-        self.proj_out = nn.Conv2d(
-            3 * in_channels, 3 * in_channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False
-        )
+        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.proj_in(hidden_states)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states


-class MultiscaleLinearAttention(nn.Module):
+class SanaMultiscaleLinearAttention(nn.Module):
     r"""Lightweight multi-scale linear attention"""

     def __init__(
@@ -792,6 +791,7 @@ def __init__(
         norm_type: str = "batch_norm",
         kernel_sizes: Tuple[int, ...] = (5,),
         eps: float = 1e-15,
+        residual_connection: bool = False,
     ):
         super().__init__()

@@ -801,6 +801,7 @@ def __init__(
         self.eps = eps
         self.attention_head_dim = attention_head_dim
         self.norm_type = norm_type
+        self.residual_connection = residual_connection

         num_attention_heads = (
             int(in_channels // attention_head_dim * heads_ratio)
@@ -809,102 +810,32 @@ def __init__(
         )
         inner_dim = num_attention_heads * attention_head_dim

-        # self.to_qkv = nn.Conv2d(in_channels, 3 * inner_dim, 1, 1, 0, bias=False)
         self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
         self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
         self.to_v = nn.Linear(in_channels, inner_dim, bias=False)

         self.to_qkv_multiscale = nn.ModuleList()
         for kernel_size in kernel_sizes:
-            self.to_qkv_multiscale.append(MultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size))
+            self.to_qkv_multiscale.append(
+                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+            )

-        self.kernel_nonlinearity = nn.ReLU()
-        self.proj_out = nn.Conv2d(inner_dim * (1 + len(kernel_sizes)), out_channels, 1, 1, 0, bias=False)
+        self.nonlinearity = nn.ReLU()
+        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
         self.norm_out = get_normalization(norm_type, num_features=out_channels)

-    def linear_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-        batch_size, _, height, width = qkv.shape
-
-        qkv = qkv.float()
-        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-
-        query, key, value = (
-            qkv[:, :, 0 : self.attention_head_dim],
-            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-            qkv[:, :, 2 * self.attention_head_dim :],
-        )
-
-        # lightweight linear attention
-        query = self.kernel_nonlinearity(query)
-        key = self.kernel_nonlinearity(key)
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
-
-        key_T = key.transpose(-1, -2)
-        scores = torch.matmul(value, key_T)
-        output = torch.matmul(scores, query)
-
-        output = output.float()
-        output = output[:, :, :-1] / (output[:, :, -1:] + self.eps)
-        output = torch.reshape(output, (batch_size, -1, height, width))
-
-        return output
-
-    def quadratic_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-        batch_size, _, height, width = list(qkv.size())
-
-        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-        query, key, value = (
-            qkv[:, :, 0 : self.attention_head_dim],
-            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-            qkv[:, :, 2 * self.attention_head_dim :],
-        )
-
-        query = self.kernel_nonlinearity(query)
-        key = self.kernel_nonlinearity(key)
-
-        scores = torch.matmul(key.transpose(-1, -2), query)
-
-        original_dtype = scores.dtype
-        scores = scores.float()
-        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        scores = scores.to(original_dtype)
-
-        output = torch.matmul(value, scores)
-        output = torch.reshape(output, (batch_size, -1, height, width))
-
-        return output
+        self.processor = SanaMultiscaleLinearAttnProcessor2_0()
+        self.processor_quadratic = SanaMultiscaleQuadraticAttnProcessor2_0()

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        residual = hidden_states
-
-        # qkv = self.to_qkv(hidden_states)
-        hidden_states = hidden_states.movedim(1, 3)
-        query = self.to_q(hidden_states)
-        key = self.to_k(hidden_states)
-        value = self.to_v(hidden_states)
-        qkv = torch.cat([query, key, value], dim=3)
-        qkv = qkv.movedim(3, 1)
-
-        multi_scale_qkv = [qkv]
-        for block in self.to_qkv_multiscale:
-            multi_scale_qkv.append(block(qkv))
+        height, width = hidden_states.shape[-2:]

-        qkv = torch.cat(multi_scale_qkv, dim=1)
-
-        height, width = qkv.shape[-2:]
         if height * width > self.attention_head_dim:
-            hidden_states = self.linear_attention(qkv).to(qkv.dtype)
-        else:
-            hidden_states = self.quadratic_attention(qkv)
-
-        hidden_states = self.proj_out(hidden_states)
-
-        if self.norm_type == "rms_norm":
-            hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+            hidden_states = self.processor(self, hidden_states)
         else:
-            hidden_states = self.norm_out(hidden_states)
+            hidden_states = self.processor_quadratic(self, hidden_states)

-        return hidden_states + residual
+        return hidden_states


 class AttnProcessor:
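Note: the refactor above moves the attention math out of SanaMultiscaleLinearAttention.forward and into dedicated processor objects; forward now only measures the spatial size and picks a processor. A minimal usage sketch follows, assuming the constructor keyword arguments not fully visible in this diff (in_channels, out_channels, attention_head_dim) keep the names used inside __init__:

import torch

# Hypothetical example values; only norm_type, kernel_sizes, eps and
# residual_connection are confirmed by the diff above, the rest are assumptions.
attn = SanaMultiscaleLinearAttention(
    in_channels=64,
    out_channels=64,
    attention_head_dim=8,
    norm_type="batch_norm",
    kernel_sizes=(5,),
    residual_connection=True,
)

hidden_states = torch.randn(1, 64, 32, 32)  # channel-first feature map

# height * width = 1024 > attention_head_dim = 8, so forward() dispatches to
# SanaMultiscaleLinearAttnProcessor2_0; very small inputs take the quadratic path.
output = attn(hidden_states)  # shape (1, out_channels, 32, 32)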
@@ -5160,6 +5091,109 @@ def __call__(
         return hidden_states


+class SanaMultiscaleLinearAttnProcessor2_0:
+    r"""
+    Processor for implementing multiscale linear attention.
+    """
+
+    def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
+        residual = hidden_states
+
+        batch_size, _, height, width = hidden_states.shape
+        original_dtype = hidden_states.dtype
+
+        hidden_states = hidden_states.movedim(1, -1)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+        hidden_states = torch.cat([query, key, value], dim=3)
+        hidden_states = hidden_states.movedim(-1, 1)
+
+        multiscale_hidden_states = [hidden_states]
+        for block in attn.to_qkv_multiscale:
+            multiscale_hidden_states.append(block(hidden_states))
+
+        hidden_states = torch.cat(multiscale_hidden_states, dim=1)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
+
+        query, key, value = hidden_states.chunk(3, dim=2)
+        query = attn.nonlinearity(query)
+        key = attn.nonlinearity(key)
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
+
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + attn.eps)
+        hidden_states = hidden_states.to(dtype=original_dtype)
+
+        hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+        hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if attn.norm_type == "rms_norm":
+            hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        else:
+            hidden_states = attn.norm_out(hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class SanaMultiscaleQuadraticAttnProcessor2_0:
+    r"""
+    Processor for implementing multiscale quadratic attention.
+    """
+
+    def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
+        residual = hidden_states
+
+        batch_size, _, height, width = list(hidden_states.size())
+        original_dtype = hidden_states.dtype
+
+        hidden_states = hidden_states.movedim(1, -1)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+        hidden_states = torch.cat([query, key, value], dim=3)
+        hidden_states = hidden_states.movedim(-1, 1)
+
+        multi_scale_qkv = [hidden_states]
+        for block in attn.to_qkv_multiscale:
+            multi_scale_qkv.append(block(hidden_states))
+
+        hidden_states = torch.cat(multi_scale_qkv, dim=1)
+
+        hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
+
+        query, key, value = hidden_states.chunk(3, dim=2)
+        query = attn.nonlinearity(query)
+        key = attn.nonlinearity(key)
+
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + attn.eps)
+        scores = scores.to(dtype=original_dtype)
+        hidden_states = torch.matmul(value, scores)
+
+        hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+        hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if attn.norm_type == "rms_norm":
+            hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        else:
+            hidden_states = attn.norm_out(hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class LoRAAttnProcessor:
     def __init__(self):
         pass
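For reference, a self-contained sketch (my own illustration, not library code) of the ReLU linear attention that SanaMultiscaleLinearAttnProcessor2_0 computes: padding value with a row of ones lets the same pair of matmuls produce both the weighted sum and its normalizer, and the final [:, :, :-1] / [:, :, -1:] division recovers the normalized output.

import torch
import torch.nn.functional as F

eps = 1e-15
head_dim, num_tokens = 8, 64

# One batch, one head; query/key pass through ReLU as in the processor.
query = torch.relu(torch.randn(1, 1, head_dim, num_tokens))
key = torch.relu(torch.randn(1, 1, head_dim, num_tokens))
value = torch.randn(1, 1, head_dim, num_tokens)

# Pad value with a row of ones so the normalizer rides along in the matmuls.
value_padded = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
scores = torch.matmul(value_padded, key.transpose(-1, -2))  # (1, 1, head_dim + 1, head_dim)
out = torch.matmul(scores, query)                           # (1, 1, head_dim + 1, num_tokens)
out = out[:, :, :-1] / (out[:, :, -1:] + eps)               # normalized output, (1, 1, head_dim, num_tokens)

# The same quantity written out explicitly: (V K^T) Q divided by (1 K^T) Q.
numerator = value @ key.transpose(-1, -2) @ query
denominator = torch.ones(1, 1, 1, num_tokens) @ key.transpose(-1, -2) @ query
print(torch.allclose(out, numerator / (denominator + eps)))  # True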