@@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding

 class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
+    def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
         self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)

     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
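Note (not part of the diff): a minimal usage sketch of the new bias flag, assuming the patched MLPEmbedder above is in scope and using plain torch.nn as a stand-in for the operations wrapper (an assumption; torch.nn.Linear accepts the same bias/dtype/device keywords).

import torch
import torch.nn as nn

# Hypothetical sizes; only bias=False is the point of the sketch.
emb = MLPEmbedder(in_dim=256, hidden_dim=3072, bias=False, operations=nn)
assert emb.in_layer.bias is None and emb.out_layer.bias is None
print(emb(torch.randn(1, 256)).shape)  # torch.Size([1, 3072])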
@@ -80,14 +80,14 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:


 class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads

         self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)


 @dataclass
@@ -98,11 +98,11 @@ class ModulationOut:


 class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.is_double = double
         self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
+        self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)

     def forward(self, vec: Tensor) -> tuple:
         if vec.ndim == 2:
@@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
     return tensor


+class SiLUActivation(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x1, x2 = x.chunk(2, dim=-1)
+        return self.gate_fn(x1) * x2
+
+
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
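Note (not part of the diff): the added SiLUActivation is a SwiGLU-style gated unit, so the Linear that feeds it must produce twice the target width. A self-contained sketch of the same math, showing the last dimension being halved:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 8)      # last dim = 8
x1, x2 = x.chunk(2, dim=-1)    # two halves of width 4
out = F.silu(x1) * x2          # SiLU-gated product, same as SiLUActivation.forward
print(out.shape)               # torch.Size([2, 16, 4])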
@@ -142,27 +152,44 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+
+        if mlp_silu_act:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )

         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+
+        if mlp_silu_act:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )
+
         self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
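Note (not part of the diff): a shape-flow sketch of the mlp_silu_act branch above. It assumes the SiLUActivation class from this patch is in scope and uses plain torch.nn in place of the operations wrapper; the hidden_size and mlp_ratio values are arbitrary. The first Linear widens to mlp_hidden_dim * 2 so the gate can bring the width back to mlp_hidden_dim before the output projection.

import torch
import torch.nn as nn

hidden_size, mlp_ratio = 64, 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)                  # 256
mlp = nn.Sequential(
    nn.Linear(hidden_size, mlp_hidden_dim * 2, bias=False),    # widen to 2x for the gate
    SiLUActivation(),                                           # halves the width again
    nn.Linear(mlp_hidden_dim, hidden_size, bias=False),         # back to hidden_size
)
print(mlp(torch.randn(1, 10, hidden_size)).shape)  # torch.Size([1, 10, 64])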
@@ -246,6 +273,8 @@ def __init__(
         mlp_ratio: float = 4.0,
         qk_scale: float = None,
         modulation=True,
+        mlp_silu_act=False,
+        bias=True,
         dtype=None,
         device=None,
         operations=None
@@ -257,17 +286,24 @@ def __init__(
         self.scale = qk_scale or head_dim ** -0.5

         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+        self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        if mlp_silu_act:
+            self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
+            self.mlp_act = SiLUActivation()
+        else:
+            self.mlp_act = nn.GELU(approximate="tanh")
+
         # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)

         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

         self.hidden_size = hidden_size
         self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

-        self.mlp_act = nn.GELU(approximate="tanh")
         if modulation:
             self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
         else:
@@ -279,7 +315,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         else:
             mod = vec

-        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         del qkv
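Note (not part of the diff): split-size arithmetic for linear1's output in the forward pass above, with example values (hidden_size=3072, mlp_ratio=4.0) that are assumptions rather than taken from the patch.

hidden_size, mlp_ratio = 3072, 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)  # 12288
for mlp_silu_act in (False, True):
    mlp_hidden_dim_first = int(hidden_size * mlp_ratio * (2 if mlp_silu_act else 1))
    split_sizes = [3 * hidden_size, mlp_hidden_dim_first]
    print(mlp_silu_act, split_sizes, sum(split_sizes))
    # False: [9216, 12288] -> linear1 out_features 21504
    # True:  [9216, 24576] -> linear1 out_features 33792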
@@ -298,11 +334,11 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation


 class LastLayer(nn.Module):
-    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
-        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
+        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))

     def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
         if vec.ndim == 2: