Skip to content

Commit 88f9867

Browse files
committed
fix: maybe fixed finally?
1 parent f1f6836 commit 88f9867

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

flaxdiff/models/simple_mmdit.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -644,8 +644,8 @@ def setup(self):
644644
]
645645

646646
# 5. --- Encoder Path (Fine to Coarse) ---
647-
self.encoder_blocks = []
648-
self.patch_mergers = []
647+
encoder_blocks = []
648+
patch_mergers = []
649649
for stage in range(num_stages):
650650
# Blocks for this stage
651651
stage_blocks = [
@@ -665,11 +665,11 @@ def setup(self):
665665
# Assuming symmetric layers for now, adjust if needed (e.g., self.num_encoder_layers[stage])
666666
for i in range(self.num_layers[stage])
667667
]
668-
self.encoder_blocks.append(stage_blocks)
668+
encoder_blocks.append(stage_blocks)
669669

670670
# Patch Merging layer (except for the last/coarsest stage)
671671
if stage < num_stages - 1:
672-
self.patch_mergers.append(
672+
patch_mergers.append(
673673
PatchMerging(
674674
out_features=self.emb_features[stage + 1], # Target next stage dim
675675
dtype=self.dtype,
@@ -678,15 +678,17 @@ def setup(self):
678678
name=f"patch_merger_{stage}"
679679
)
680680
)
681-
681+
self.encoder_blocks = encoder_blocks
682+
self.patch_mergers = patch_mergers
683+
682684
# 6. --- Decoder Path (Coarse to Fine) ---
683-
self.decoder_blocks = []
684-
self.patch_expanders = []
685-
self.fusion_layers = []
685+
decoder_blocks = []
686+
patch_expanders = []
687+
fusion_layers = []
686688
# Iterate from second coarsest stage (N-2) down to finest (0)
687689
for stage in range(num_stages - 2, -1, -1):
688690
# Patch Expanding layer (Expands from stage+1 to stage)
689-
self.patch_expanders.append(
691+
patch_expanders.append(
690692
PatchExpanding(
691693
out_features=self.emb_features[stage], # Target current stage dim
692694
dtype=self.dtype,
@@ -696,7 +698,7 @@ def setup(self):
696698
)
697699
)
698700
# Fusion layer (Combines skip[stage] and expanded[stage+1]->[stage])
699-
self.fusion_layers.append(
701+
fusion_layers.append(
700702
nn.Sequential([ # Use Sequential for Norm + Dense
701703
nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name=f"fusion_norm_{stage}"),
702704
nn.Dense(
@@ -727,8 +729,12 @@ def setup(self):
727729
for i in range(self.num_layers[stage])
728730
]
729731
# Append blocks in order: stage N-2, N-3, ..., 0
730-
self.decoder_blocks.append(stage_blocks)
732+
decoder_blocks.append(stage_blocks)
731733

734+
self.patch_expanders = patch_expanders
735+
self.fusion_layers = fusion_layers
736+
self.decoder_blocks = decoder_blocks
737+
732738
# Note: The lists expanders, fusion_layers, decoder_blocks are now ordered
733739
# corresponding to stages N-2, N-3, ..., 0.
734740

0 commit comments

Comments (0)