@@ -106,7 +106,7 @@ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_ba
         return model_output


-class Attention(torch.nn.Module):
+class ConvAttention(torch.nn.Module):

     def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
         super().__init__()
@@ -115,20 +115,25 @@ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_k
         self.num_heads = num_heads
         self.head_dim = head_dim

-        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
-        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
-        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
-        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+        self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
+        self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)

     def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states

         batch_size = encoder_hidden_states.shape[0]

-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
+        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+        q = self.to_q(conv_input)
+        q = rearrange(q[:, :, :, 0], "B C L -> B L C")
+        conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
+        k = self.to_k(conv_input)
+        v = self.to_v(conv_input)
+        k = rearrange(k[:, :, :, 0], "B C L -> B L C")
+        v = rearrange(v[:, :, :, 0], "B C L -> B L C")

         q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -138,7 +143,9 @@ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)

-        hidden_states = self.to_out(hidden_states)
+        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+        hidden_states = self.to_out(conv_input)
+        hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")

         return hidden_states

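Side note on the hunks above (not part of the commit): the Linear projections become Conv2d layers with a 1x1 kernel, and each token tensor is rearranged from "B L C" to "B C L 1" before the projection and back afterwards. A 1x1 convolution over that layout applies the same per-token affine map as the original Linear, so the attention math should be unchanged; only the parameter layout differs. A minimal equivalence sketch, with B, L, C, D as made-up illustrative sizes:

# Sketch only: checks that a kernel_size=(1, 1) Conv2d over a "B C L 1" view
# matches a Linear over "B L C" once the weights are shared.
import torch
from einops import rearrange

B, L, C, D = 2, 16, 32, 64                                # hypothetical sizes
linear = torch.nn.Linear(C, D, bias=False)
conv = torch.nn.Conv2d(C, D, kernel_size=(1, 1), bias=False)
conv.weight.data = linear.weight.data[:, :, None, None]  # (D, C) -> (D, C, 1, 1)

x = torch.randn(B, L, C)
y_linear = linear(x)
y_conv = conv(rearrange(x, "B L C -> B C L 1"))
y_conv = rearrange(y_conv[:, :, :, 0], "B C L -> B L C")
print(torch.allclose(y_linear, y_conv, atol=1e-5))        # True
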
@@ -152,7 +159,7 @@ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_lay
         self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

         self.transformer_blocks = torch.nn.ModuleList([
-            Attention(
+            ConvAttention(
                 inner_dim,
                 num_attention_heads,
                 attention_head_dim,
@@ -236,7 +243,7 @@ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         return hidden_states, time_emb, text_emb, res_stack


-class SD3VAEDecoder(torch.nn.Module):
+class FluxVAEDecoder(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.scaling_factor = 0.3611
@@ -308,7 +315,7 @@ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
         return hidden_states


-class SD3VAEEncoder(torch.nn.Module):
+class FluxVAEEncoder(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.scaling_factor = 0.3611
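For orientation, a rough usage sketch of the renamed classes (not part of the commit), assuming the encoder mirrors the decoder's forward(sample, tiled=..., tile_size=..., tile_stride=...) signature shown above and that weights are loaded separately; the image tensor and shapes are hypothetical:

# Usage sketch only; the encode/decode round trip and shapes are illustrative.
import torch

encoder = FluxVAEEncoder()
decoder = FluxVAEDecoder()

image = torch.randn(1, 3, 512, 512)                       # hypothetical input image
latents = encoder(image, tiled=True, tile_size=64, tile_stride=32)
reconstruction = decoder(latents, tiled=True, tile_size=64, tile_stride=32)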