@@ -175,6 +175,8 @@ def __call__(self, x, conditioning):
         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
             ada_params, 6, axis=-1)
 
+        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
         # Apply Layer Normalization
         norm = nn.LayerNorm(epsilon=self.norm_epsilon,
                             use_scale=False, use_bias=False, dtype=self.dtype)
@@ -191,48 +193,104 @@ def __call__(self, x, conditioning):
         # Return modulated outputs and gates
         return x_attn, gate_attn, x_mlp, gate_mlp
 
+class AdaLNParams(nn.Module):  # Renamed for clarity
+    features: int
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
 
-# --- DiT Block ---
+    @nn.compact
+    def __call__(self, conditioning):
+        # Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
+        if conditioning.ndim == 2:
+            conditioning = jnp.expand_dims(conditioning, axis=1)
 
+        # Project conditioning to get 6 params per feature
+        ada_params = nn.Dense(
+            features=6 * self.features,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="ada_proj"
+        )(conditioning)
+        # Return all params (or split if preferred, but maybe return tuple/dict)
+        # Shape: [B, 1, 6*F]
+        return ada_params  # Or split and return tuple: jnp.split(ada_params, 6, axis=-1)
+
+# --- DiT Block ---
 class DiTBlock(nn.Module):
     features: int
     num_heads: int
-    rope_emb: RotaryEmbedding  # Pass RoPE module
+    rope_emb: RotaryEmbedding
     mlp_ratio: int = 4
-    dropout_rate: float = 0.0  # Typically dropout is not used in diffusion models
+    dropout_rate: float = 0.0
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    # Keep option, but RoPEAttention uses NormalAttention base
-    use_flash_attention: bool = False
+    use_flash_attention: bool = False  # Keep placeholder
     force_fp32_for_softmax: bool = True
     norm_epsilon: float = 1e-5
+    use_gating: bool = True  # Add flag to easily disable gating
 
     def setup(self):
         hidden_features = int(self.features * self.mlp_ratio)
-        self.ada_ln_zero = AdaLNZero(
-            self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
+        # Get modulation parameters (scale, shift, gates)
+        self.ada_params_module = AdaLNParams(  # Use the modified module
+            self.features, dtype=self.dtype, precision=self.precision)
+
+        # Layer Norms - one before Attn, one before MLP
+        self.norm1 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm1")
+        self.norm2 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm2")
 
-        # Use RoPEAttention
         self.attention = RoPEAttention(
             query_dim=self.features,
             heads=self.num_heads,
             dim_head=self.features // self.num_heads,
             dtype=self.dtype,
             precision=self.precision,
-            use_bias=True,  # Bias is common in DiT attention proj
+            use_bias=True,
             force_fp32_for_softmax=self.force_fp32_for_softmax,
-            rope_emb=self.rope_emb  # Pass RoPE module instance
+            rope_emb=self.rope_emb
         )
 
-        # Standard MLP block
         self.mlp = nn.Sequential([
-            nn.Dense(features=hidden_features, dtype=self.dtype,
-                     precision=self.precision),
-            nn.gelu,
-            nn.Dense(features=self.features, dtype=self.dtype,
-                     precision=self.precision)
+            nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
+            nn.gelu,  # Or swish as specified in SimpleDiT? Consider consistency.
+            nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
         ])
 
+    @nn.compact
+    def __call__(self, x, conditioning, freqs_cis):
+        # Get scale/shift/gate parameters
+        # Shape: [B, 1, 6*F] -> split into 6 of [B, 1, F]
+        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+            self.ada_params_module(conditioning), 6, axis=-1
+        )
+
+        # --- Attention Path ---
+        residual = x
+        norm_x_attn = self.norm1(x)
+        # Modulate after norm
+        x_attn_modulated = norm_x_attn * (1 + scale_attn) + shift_attn
+        attn_output = self.attention(x_attn_modulated, context=None, freqs_cis=freqs_cis)
+
+        if self.use_gating:
+            x = residual + gate_attn * attn_output
+        else:
+            x = residual + attn_output  # Original DiT style without gate
+
+        # --- MLP Path ---
+        residual = x
+        norm_x_mlp = self.norm2(x)  # Apply second LayerNorm
+        # Modulate after norm
+        x_mlp_modulated = norm_x_mlp * (1 + scale_mlp) + shift_mlp
+        mlp_output = self.mlp(x_mlp_modulated)
+
+        if self.use_gating:
+            x = residual + gate_mlp * mlp_output
+        else:
+            x = residual + mlp_output  # Original DiT style without gate
+
+        return x
+
     @nn.compact
     def __call__(self, x, conditioning, freqs_cis):
         # x shape: [B, S, F]
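
For reference, the adaLN-Zero pattern the new DiTBlock.__call__ applies above (zero-initialized conditioning projection, norm(x) * (1 + scale) + shift modulation, gated residual add) can be exercised in isolation. The sketch below uses only flax.linen and jax.numpy; the module name ModulationSketch, the stand-in sublayer Dense, and the toy shapes are illustrative assumptions rather than part of this repository, and the repo-specific RoPEAttention/RotaryEmbedding pieces are deliberately left out.

import jax
import jax.numpy as jnp
import flax.linen as nn


class ModulationSketch(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, conditioning):
        # x: [B, S, F]; conditioning: [B, D_cond]
        if conditioning.ndim == 2:
            conditioning = jnp.expand_dims(conditioning, axis=1)  # -> [B, 1, D_cond]

        # Zero-initialized projection: all scales, shifts and gates start at 0,
        # so the block is an identity map at initialization (the "Zero" in adaLN-Zero).
        ada = nn.Dense(features=6 * self.features,
                       kernel_init=nn.initializers.zeros,
                       name="ada_proj")(conditioning)
        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
            ada, 6, axis=-1)

        norm = nn.LayerNorm(use_scale=False, use_bias=False)
        # Stand-in for the attention and MLP sub-layers so the sketch runs on its own;
        # its parameters are shared between the two uses, which is fine for illustration.
        sublayer = nn.Dense(features=self.features, name="sublayer")

        # "Attention" branch: modulate the normalized input, then gated residual add.
        h = norm(x) * (1 + scale_attn) + shift_attn
        x = x + gate_attn * sublayer(h)

        # "MLP" branch: same pattern with the second parameter set.
        h = norm(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * sublayer(h)
        return x


x = jnp.ones((2, 16, 32))          # [B, S, F]
cond = jnp.ones((2, 8))            # [B, D_cond]
block = ModulationSketch(features=32)
params = block.init(jax.random.PRNGKey(0), x, cond)
y = block.apply(params, x, cond)   # equals x at init because every gate is zero

The zero-initialized ada_proj is what makes each residual branch a no-op at the start of training, which is also why the diff sets kernel_init=nn.initializers.zeros in AdaLNParams.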