@@ -57,6 +57,35 @@ def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=N
     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
 
+class YakMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+    if yak_mlp:
+        return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+    if mlp_silu_act:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+            SiLUActivation(),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+        )
+    else:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
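The new `YakMLP` is a SwiGLU-style gated MLP (`down_proj(silu(gate_proj(x)) * up_proj(x))`, here with biases), and `build_mlp` centralizes the three MLP variants that were previously built inline in each block. As a quick sanity check, here is a minimal sketch (not part of the diff) that runs the gated path on toy shapes; it assumes only that `operations` exposes a `Linear` compatible with `torch.nn.Linear` (including the `dtype`/`device` keywords), so plain `torch.nn` stands in for the ops namespace, and that it runs in or imports from the module where `YakMLP` is defined.

```python
# Smoke test for the gated-MLP path; toy sizes, illustrative only.
# Assumption: `operations` only needs a Linear with the torch.nn.Linear
# signature, so torch.nn is used as a stand-in here.
import torch
from torch import nn

hidden_size, intermediate_size = 64, 256
mlp = YakMLP(hidden_size, intermediate_size, operations=nn)

x = torch.randn(2, 10, hidden_size)   # (batch, tokens, hidden)
out = mlp(x)                          # silu(gate_proj(x)) * up_proj(x) -> down_proj
assert out.shape == x.shape
```

The same helper replaces the duplicated `img_mlp`/`txt_mlp` construction in `DoubleStreamBlock` below.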
@@ -140,7 +169,7 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
 
         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
 
-        if mlp_silu_act:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
 
         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
 
-        if mlp_silu_act:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         self.flipped_img_txt = flipped_img_txt
 
@@ -275,6 +282,7 @@ def __init__(
         modulation=True,
         mlp_silu_act=False,
         bias=True,
+        yak_mlp=False,
         dtype=None,
         device=None,
         operations=None
@@ -288,12 +296,17 @@ def __init__(
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
 
         self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        self.yak_mlp = yak_mlp
         if mlp_silu_act:
             self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
             self.mlp_act = SiLUActivation()
         else:
             self.mlp_act = nn.GELU(approximate="tanh")
 
+        if self.yak_mlp:
+            self.mlp_hidden_dim_first *= 2
+            self.mlp_act = nn.SiLU()
+
         # qkv and mlp_in
         self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
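In the single-stream block the gate and up projections are not separate modules; they ride along in the fused `linear1` output, which is why `mlp_hidden_dim_first` is doubled when `yak_mlp` is set. A small sketch of the width bookkeeping, with toy numbers rather than any real config:

```python
# Width bookkeeping for the fused linear1 (illustrative numbers only).
hidden_size, mlp_ratio = 3072, 4.0
mlp_hidden_dim_first = int(hidden_size * mlp_ratio)    # 12288: plain GELU path
yak_mlp = True
if yak_mlp:
    mlp_hidden_dim_first *= 2                          # 24576: gate and up halves fused
linear1_out = hidden_size * 3 + mlp_hidden_dim_first   # QKV (3 * hidden) + MLP stream
assert linear1_out == 33792
```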
@@ -325,7 +338,10 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        mlp = self.mlp_act(mlp)
+        if self.yak_mlp:
+            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+        else:
+            mlp = self.mlp_act(mlp)
         output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
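In `forward`, the `yak_mlp` branch reproduces `YakMLP`'s gating on that fused slice: the second half of the MLP stream is SiLU-gated and multiplies the first half, with the down projection folded into `linear2` alongside the attention projection. A sketch of the slice layout, using hypothetical toy tensors rather than anything from the PR:

```python
# Gating layout used in the yak_mlp branch above (toy tensors, illustrative only).
import torch
import torch.nn.functional as F

mlp_hidden_dim_first = 16                        # toy fused width (already doubled)
half = mlp_hidden_dim_first // 2
mlp = torch.randn(2, 10, mlp_hidden_dim_first)   # MLP stream as emitted by linear1

# second half acts as the gate, first half as the "up" value,
# mirroring act_fn(gate_proj(x)) * up_proj(x) in YakMLP
gated = F.silu(mlp[..., half:]) * mlp[..., :half]
assert gated.shape == (2, 10, half)              # linear2 then maps cat(attn, gated) back to hidden
```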