@@ -331,6 +331,14 @@ def setup(self):
             precision=self.precision)  # Final projection
         ], name="time_embed")
 
+        # Add projection layer for Hilbert patches
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
         # Text context projection (output dim: emb_features)
         # Input dim depends on the text encoder output, assumed to be handled externally
         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
@@ -383,18 +391,17 @@ def __call__(self, x, temb, textcontext): # textcontext is required
         assert textcontext is not None, "textcontext must be provided for SimpleMMDiT"
 
         # 1. Patch Embedding
-        patches = self.patch_embed(x)  # Shape: [B, num_patches, emb_features]
-        num_patches = patches.shape[1]
-
-        # Optional Hilbert reorder
-        hilbert_inv_idx = None
         if self.use_hilbert:
-            idx = hilbert_indices(H // self.patch_size, W // self.patch_size)
-            hilbert_inv_idx = inverse_permutation(
-                idx)  # Store inverse for unpatchify
-            patches = patches[:, idx, :]
+            # Use hilbert_patchify which handles both patchification and reordering
+            patches_raw, hilbert_inv_idx = hilbert_patchify(x, self.patch_size)  # Shape [B, S, P*P*C]
+            # Apply projection
+            patches = self.hilbert_proj(patches_raw)  # Shape [B, S, emb_features]
+        else:
+            patches = self.patch_embed(x)  # Shape: [B, num_patches, emb_features]
+            hilbert_inv_idx = None
 
-        x_seq = patches  # Shape: [B, num_patches, emb_features]
+        num_patches = patches.shape[1]
+        x_seq = patches
 
         # 2. Prepare Conditioning Signals
         t_emb = self.time_embed(temb)  # Shape: [B, emb_features]
@@ -419,21 +426,35 @@ def __call__(self, x, temb, textcontext): # textcontext is required
         x_seq = self.final_proj(x_seq)
 
         # 6. Unpatchify
-        # Optional Hilbert unorder
-        if self.use_hilbert and hilbert_inv_idx is not None:
-            x_seq = x_seq[:, hilbert_inv_idx, :]
-
-        # Determine output channels for unpatchify
-        final_out_channels = self.output_channels * \
-            (2 if self.learn_sigma else 1)
-
-        # Reshape back to image space
-        # Shape: [B, H, W, C (*2 if learn_sigma)]
-        out = unpatchify(x_seq, channels=final_out_channels)
+        if self.use_hilbert:
+            # For Hilbert mode, we need to use the specialized unpatchify function
+            if self.learn_sigma:
+                # Split into mean and variance predictions
+                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                # If needed, also unpack the logvar
+                # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                # return x_image, logvar_image
+                return x_image
+            else:
+                x_image = hilbert_unpatchify(x_seq, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                return x_image
+        else:
+            # Standard patch ordering - use the existing unpatchify function
+            if self.learn_sigma:
+                # Split into mean and variance predictions
+                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                x = unpatchify(x_mean, channels=self.output_channels)
+                # Return both mean and logvar if needed by the loss function
+                # For now, just returning the mean prediction like standard diffusion models
+                # logvar = unpatchify(x_logvar, channels=self.output_channels)
+                # return x, logvar
+                return x
+            else:
+                # Shape: [B, H, W, C]
+                x = unpatchify(x_seq, channels=self.output_channels)
+                return x
 
-        # If learn_sigma is True, the output has doubled channels.
-        # The caller is responsible for splitting if needed.
-        return out
 
 
 # --- Hierarchical MM-DiT components ---
@@ -573,6 +594,15 @@ def setup(self):
             name="text_context_proj"
         )
 
+        # Add projection layer for Hilbert patches
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
         # Create RoPE embeddings for each stage
         self.ropes = [
             RotaryEmbedding(