
Commit 94269cc

fixes in dit
1 parent 0513b8b commit 94269cc

File tree

2 files changed: +77 lines, -17 lines


flaxdiff/models/simple_dit.py

Lines changed: 74 additions & 16 deletions
@@ -175,6 +175,8 @@ def __call__(self, x, conditioning):
         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
             ada_params, 6, axis=-1)

+        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
         # Apply Layer Normalization
         norm = nn.LayerNorm(epsilon=self.norm_epsilon,
                             use_scale=False, use_bias=False, dtype=self.dtype)
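
The two new jnp.clip calls bound the adaLN scale and shift before they modulate the normalized activations (the commit clips only the MLP-path parameters; the attention-path parameters and gates are left as-is). A minimal sketch of the pattern, assuming a [B, 1, F] parameter tensor and a [B, S, F] normalized input; the helper name and the limit argument are illustrative, not part of the repository:

import jax.numpy as jnp

def modulate(norm_x, scale, shift, limit=10.0):
    # Bound the learned modulation parameters so an unlucky conditioning
    # projection cannot blow up the activations early in training.
    scale = jnp.clip(scale, -limit, limit)
    shift = jnp.clip(shift, -limit, limit)
    # [B, 1, F] parameters broadcast over the sequence axis of [B, S, F].
    return norm_x * (1 + scale) + shift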
@@ -191,48 +193,104 @@ def __call__(self, x, conditioning):
         # Return modulated outputs and gates
         return x_attn, gate_attn, x_mlp, gate_mlp

+class AdaLNParams(nn.Module):  # Renamed for clarity
+    features: int
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None

-# --- DiT Block ---
+    @nn.compact
+    def __call__(self, conditioning):
+        # Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
+        if conditioning.ndim == 2:
+            conditioning = jnp.expand_dims(conditioning, axis=1)

+        # Project conditioning to get 6 params per feature
+        ada_params = nn.Dense(
+            features=6 * self.features,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="ada_proj"
+        )(conditioning)
+        # Return all params (or split if preferred, but maybe return tuple/dict)
+        # Shape: [B, 1, 6*F]
+        return ada_params  # Or split and return tuple: jnp.split(ada_params, 6, axis=-1)
+
+# --- DiT Block ---
 class DiTBlock(nn.Module):
     features: int
     num_heads: int
-    rope_emb: RotaryEmbedding  # Pass RoPE module
+    rope_emb: RotaryEmbedding
     mlp_ratio: int = 4
-    dropout_rate: float = 0.0  # Typically dropout is not used in diffusion models
+    dropout_rate: float = 0.0
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    # Keep option, but RoPEAttention uses NormalAttention base
-    use_flash_attention: bool = False
+    use_flash_attention: bool = False  # Keep placeholder
     force_fp32_for_softmax: bool = True
     norm_epsilon: float = 1e-5
+    use_gating: bool = True  # Add flag to easily disable gating

     def setup(self):
         hidden_features = int(self.features * self.mlp_ratio)
-        self.ada_ln_zero = AdaLNZero(
-            self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
+        # Get modulation parameters (scale, shift, gates)
+        self.ada_params_module = AdaLNParams(  # Use the modified module
+            self.features, dtype=self.dtype, precision=self.precision)
+
+        # Layer Norms - one before Attn, one before MLP
+        self.norm1 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm1")
+        self.norm2 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm2")

-        # Use RoPEAttention
         self.attention = RoPEAttention(
             query_dim=self.features,
             heads=self.num_heads,
             dim_head=self.features // self.num_heads,
             dtype=self.dtype,
             precision=self.precision,
-            use_bias=True,  # Bias is common in DiT attention proj
+            use_bias=True,
             force_fp32_for_softmax=self.force_fp32_for_softmax,
-            rope_emb=self.rope_emb  # Pass RoPE module instance
+            rope_emb=self.rope_emb
         )

-        # Standard MLP block
         self.mlp = nn.Sequential([
-            nn.Dense(features=hidden_features, dtype=self.dtype,
-                     precision=self.precision),
-            nn.gelu,
-            nn.Dense(features=self.features, dtype=self.dtype,
-                     precision=self.precision)
+            nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
+            nn.gelu,  # Or swish as specified in SimpleDiT? Consider consistency.
+            nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
         ])

+    @nn.compact
+    def __call__(self, x, conditioning, freqs_cis):
+        # Get scale/shift/gate parameters
+        # Shape: [B, 1, 6*F] -> split into 6 of [B, 1, F]
+        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+            self.ada_params_module(conditioning), 6, axis=-1
+        )
+
+        # --- Attention Path ---
+        residual = x
+        norm_x_attn = self.norm1(x)
+        # Modulate after norm
+        x_attn_modulated = norm_x_attn * (1 + scale_attn) + shift_attn
+        attn_output = self.attention(x_attn_modulated, context=None, freqs_cis=freqs_cis)
+
+        if self.use_gating:
+            x = residual + gate_attn * attn_output
+        else:
+            x = residual + attn_output  # Original DiT style without gate
+
+        # --- MLP Path ---
+        residual = x
+        norm_x_mlp = self.norm2(x)  # Apply second LayerNorm
+        # Modulate after norm
+        x_mlp_modulated = norm_x_mlp * (1 + scale_mlp) + shift_mlp
+        mlp_output = self.mlp(x_mlp_modulated)
+
+        if self.use_gating:
+            x = residual + gate_mlp * mlp_output
+        else:
+            x = residual + mlp_output  # Original DiT style without gate
+
+        return x
+
     @nn.compact
     def __call__(self, x, conditioning, freqs_cis):
         # x shape: [B, S, F]
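
For reference, the block added above follows the adaLN-Zero recipe: a zero-initialized projection of the conditioning produces scale/shift/gate, the input is modulated after a parameter-free LayerNorm, and the branch output is added back through the gate. A self-contained sketch of that recipe, simplified to one branch and three parameters instead of six; the module and variable names are illustrative, not the repository's:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MiniAdaLNBlock(nn.Module):
    # Simplified stand-in for AdaLNParams + DiTBlock's modulation path;
    # the attention/MLP branch is replaced by a single Dense to keep it short.
    features: int

    @nn.compact
    def __call__(self, x, cond):
        if cond.ndim == 2:
            cond = cond[:, None, :]                       # [B, 1, D_cond]
        # Zero-init projection => scale = shift = gate = 0 at initialization.
        params = nn.Dense(3 * self.features,
                          kernel_init=nn.initializers.zeros)(cond)
        scale, shift, gate = jnp.split(params, 3, axis=-1)
        h = nn.LayerNorm(use_scale=False, use_bias=False)(x)
        h = h * (1 + scale) + shift                       # pre-norm modulation
        h = nn.Dense(self.features)(h)                    # stand-in for attn/MLP
        return x + gate * h                               # gated residual

x = jnp.ones((2, 16, 32))
cond = jnp.ones((2, 64))
y, _ = MiniAdaLNBlock(features=32).init_with_output(jax.random.PRNGKey(0), x, cond)
print(jnp.allclose(y, x))  # True: the block starts as the identity map

At initialization every block is therefore an identity map, which is the property the zero kernel_init on ada_proj is there to provide.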

flaxdiff/models/simple_mmdit.py

Lines changed: 3 additions & 1 deletion
@@ -208,6 +208,8 @@ def __call__(self, x, t_emb, text_emb):
         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
             ada_params, 6, axis=-1)  # Each shape: [B, 1, F]

+        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
         # Apply modulation for Attention path (broadcasting handled by JAX)
         x_attn = norm_x * (1 + scale_attn) + shift_attn
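
The "broadcasting handled by JAX" comment above refers to NumPy-style broadcasting: each modulation tensor has shape [B, 1, F] and is applied to per-token activations of shape [B, S, F]. A quick shape check with illustrative values only:

import jax.numpy as jnp

B, S, F = 2, 8, 4
norm_x = jnp.ones((B, S, F))            # per-token normalized activations
scale_attn = jnp.full((B, 1, F), 0.5)   # one modulation vector per sample
shift_attn = jnp.zeros((B, 1, F))

x_attn = norm_x * (1 + scale_attn) + shift_attn
print(x_attn.shape)  # (2, 8, 4): the [B, 1, F] parameters broadcast over S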

@@ -250,7 +252,7 @@ def setup(self):
             precision=self.precision,
             use_bias=True,  # Bias is common in DiT attention proj
             force_fp32_for_softmax=self.force_fp32_for_softmax,
-            rope_emb=self.rope  # Pass RoPE module instance
+            rope_emb=self.rope_emb  # Pass RoPE module instance
         )

         # Standard MLP block remains the same
