
Commit 6f17a43

fixed patchify in mmdit

1 parent e3322b7

File tree

2 files changed: +61 -30 lines


flaxdiff/models/simple_dit.py

Lines changed: 7 additions & 6 deletions
@@ -325,12 +325,13 @@ def setup(self):
         )
 
         # Add projection layer for Hilbert patches
-        self.hilbert_proj = nn.Dense(
-            features=self.emb_features,
-            dtype=self.dtype,
-            precision=self.precision,
-            name="hilbert_projection"
-        )
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
 
         # Time embedding projection
         self.time_embed = nn.Sequential([
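
In Flax, submodules assigned in setup become part of the module definition, so both files now register hilbert_projection only when use_hilbert is enabled; the Hilbert path bypasses patch_embed and therefore needs its own projection from raw P*P*C patch vectors to emb_features. A minimal sketch of the guarded pattern (module name and shapes are illustrative, not taken from flaxdiff):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Embedder(nn.Module):
        emb_features: int = 8
        use_hilbert: bool = False

        def setup(self):
            # Registered only when the feature is on; the Hilbert path
            # bypasses patch_embed, so it projects raw patches itself.
            if self.use_hilbert:
                self.hilbert_proj = nn.Dense(self.emb_features)
            self.patch_embed = nn.Dense(self.emb_features)  # stand-in patch embedder

        def __call__(self, patches_raw):
            if self.use_hilbert:
                return self.hilbert_proj(patches_raw)  # [B, S, P*P*C] -> [B, S, emb]
            return self.patch_embed(patches_raw)

    x = jnp.ones((1, 16, 48))  # B=1, S=16 patches, P*P*C = 4*4*3
    params = Embedder(use_hilbert=True).init(jax.random.PRNGKey(0), x)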

flaxdiff/models/simple_mmdit.py

Lines changed: 54 additions & 24 deletions
@@ -331,6 +331,14 @@ def setup(self):
                     precision=self.precision) # Final projection
         ], name="time_embed")
 
+        # Add projection layer for Hilbert patches
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
         # Text context projection (output dim: emb_features)
         # Input dim depends on the text encoder output, assumed to be handled externally
         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
@@ -383,18 +391,17 @@ def __call__(self, x, temb, textcontext): # textcontext is required
         assert textcontext is not None, "textcontext must be provided for SimpleMMDiT"
 
         # 1. Patch Embedding
-        patches = self.patch_embed(x) # Shape: [B, num_patches, emb_features]
-        num_patches = patches.shape[1]
-
-        # Optional Hilbert reorder
-        hilbert_inv_idx = None
         if self.use_hilbert:
-            idx = hilbert_indices(H // self.patch_size, W // self.patch_size)
-            hilbert_inv_idx = inverse_permutation(
-                idx)  # Store inverse for unpatchify
-            patches = patches[:, idx, :]
+            # Use hilbert_patchify which handles both patchification and reordering
+            patches_raw, hilbert_inv_idx = hilbert_patchify(x, self.patch_size)  # Shape [B, S, P*P*C]
+            # Apply projection
+            patches = self.hilbert_proj(patches_raw)  # Shape [B, S, emb_features]
+        else:
+            patches = self.patch_embed(x)  # Shape: [B, num_patches, emb_features]
+            hilbert_inv_idx = None
 
-        x_seq = patches # Shape: [B, num_patches, emb_features]
+        num_patches = patches.shape[1]
+        x_seq = patches
 
         # 2. Prepare Conditioning Signals
         t_emb = self.time_embed(temb) # Shape: [B, emb_features]
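
hilbert_patchify is called but not defined in this diff. A plausible sketch, consistent with the removed inline code (hilbert_indices plus inverse_permutation) and the [B, S, P*P*C] shape noted in the new comment; the curve walk assumes a square power-of-two patch grid, and everything besides the two helper names visible in the old code is illustrative:

    import jax.numpy as jnp

    def hilbert_indices(h, w):
        # Visit order of an h x w patch grid along a Hilbert curve,
        # via the classic d2xy walk (assumes h == w == 2^k).
        n = h
        order = []
        for d in range(n * n):
            t, x, y, s = d, 0, 0, 1
            while s < n:
                rx = 1 & (t // 2)
                ry = 1 & (t ^ rx)
                if ry == 0:  # rotate the quadrant
                    if rx == 1:
                        x, y = s - 1 - x, s - 1 - y
                    x, y = y, x
                x, y = x + s * rx, y + s * ry
                t //= 4
                s *= 2
            order.append(y * w + x)
        return jnp.array(order)

    def inverse_permutation(idx):
        # inv such that reordered[inv] restores the original order
        return jnp.argsort(idx)

    def hilbert_patchify(x, patch_size):
        # Raw (unprojected) patches in Hilbert order, plus the inverse index.
        B, H, W, C = x.shape
        h, w, P = H // patch_size, W // patch_size, patch_size
        patches = x.reshape(B, h, P, w, P, C).transpose(0, 1, 3, 2, 4, 5)
        patches = patches.reshape(B, h * w, P * P * C)
        idx = hilbert_indices(h, w)
        return patches[:, idx, :], inverse_permutation(idx)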
@@ -419,21 +426,35 @@ def __call__(self, x, temb, textcontext): # textcontext is required
         x_seq = self.final_proj(x_seq)
 
         # 6. Unpatchify
-        # Optional Hilbert unorder
-        if self.use_hilbert and hilbert_inv_idx is not None:
-            x_seq = x_seq[:, hilbert_inv_idx, :]
-
-        # Determine output channels for unpatchify
-        final_out_channels = self.output_channels * \
-            (2 if self.learn_sigma else 1)
-
-        # Reshape back to image space
-        # Shape: [B, H, W, C (*2 if learn_sigma)]
-        out = unpatchify(x_seq, channels=final_out_channels)
+        if self.use_hilbert:
+            # For Hilbert mode, we need to use the specialized unpatchify function
+            if self.learn_sigma:
+                # Split into mean and variance predictions
+                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                # If needed, also unpack the logvar
+                # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                # return x_image, logvar_image
+                return x_image
+            else:
+                x_image = hilbert_unpatchify(x_seq, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                return x_image
+        else:
+            # Standard patch ordering - use the existing unpatchify function
+            if self.learn_sigma:
+                # Split into mean and variance predictions
+                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                x = unpatchify(x_mean, channels=self.output_channels)
+                # Return both mean and logvar if needed by the loss function
+                # For now, just returning the mean prediction like standard diffusion models
+                # logvar = unpatchify(x_logvar, channels=self.output_channels)
+                # return x, logvar
+                return x
+            else:
+                # Shape: [B, H, W, C]
+                x = unpatchify(x_seq, channels=self.output_channels)
+                return x
 
-        # If learn_sigma is True, the output has doubled channels.
-        # The caller is responsible for splitting if needed.
-        return out
 
 
 # --- Hierarchical MM-DiT components ---
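
hilbert_unpatchify is likewise only visible here through its call sites, which fix the signature as (x, inv_idx, patch_size, H, W, channels). A sketch that exactly inverts the hilbert_patchify sketch above (assumed, not the flaxdiff implementation):

    import jax.numpy as jnp

    def hilbert_unpatchify(x, inv_idx, patch_size, H, W, channels):
        # Undo the Hilbert reorder, then fold [B, S, P*P*C] patches
        # back into a [B, H, W, C] image.
        B = x.shape[0]
        h, w, P = H // patch_size, W // patch_size, patch_size
        x = x[:, inv_idx, :]  # restore row-major patch order
        x = x.reshape(B, h, w, P, P, channels).transpose(0, 1, 3, 2, 4, 5)
        return x.reshape(B, H, W, channels)

Note that x_seq has passed through final_proj by this point, so its last dimension must already equal patch_size * patch_size * output_channels (doubled when learn_sigma is set, hence the jnp.split before the call).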
@@ -573,6 +594,15 @@ def setup(self):
             name="text_context_proj"
         )
 
+        # Add projection layer for Hilbert patches
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
         # Create RoPE embeddings for each stage
         self.ropes = [
             RotaryEmbedding(
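
A quick round-trip check of the two sketched helpers (hypothetical test code, not part of the commit); the reorder plus reshape being lossless is what lets the patch ordering change without touching the transformer blocks in between:

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 16, 3))
    patches, inv_idx = hilbert_patchify(x, patch_size=4)        # [2, 16, 48]
    x_rec = hilbert_unpatchify(patches, inv_idx, 4, 16, 16, 3)  # [2, 16, 16, 3]
    assert jnp.allclose(x, x_rec)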
