@@ -644,8 +644,8 @@ def setup(self):
644644 ]
645645
646646 # 5. --- Encoder Path (Fine to Coarse) ---
647- self . encoder_blocks = []
648- self . patch_mergers = []
647+ encoder_blocks = []
648+ patch_mergers = []
649649 for stage in range (num_stages ):
650650 # Blocks for this stage
651651 stage_blocks = [
@@ -665,11 +665,11 @@ def setup(self):
665665 # Assuming symmetric layers for now, adjust if needed (e.g., self.num_encoder_layers[stage])
666666 for i in range (self .num_layers [stage ])
667667 ]
668- self . encoder_blocks .append (stage_blocks )
668+ encoder_blocks .append (stage_blocks )
669669
670670 # Patch Merging layer (except for the last/coarsest stage)
671671 if stage < num_stages - 1 :
672- self . patch_mergers .append (
672+ patch_mergers .append (
673673 PatchMerging (
674674 out_features = self .emb_features [stage + 1 ], # Target next stage dim
675675 dtype = self .dtype ,
@@ -678,15 +678,17 @@ def setup(self):
678678 name = f"patch_merger_{ stage } "
679679 )
680680 )
681-
681+ self .encoder_blocks = encoder_blocks
682+ self .patch_mergers = patch_mergers
683+
682684 # 6. --- Decoder Path (Coarse to Fine) ---
683- self . decoder_blocks = []
684- self . patch_expanders = []
685- self . fusion_layers = []
685+ decoder_blocks = []
686+ patch_expanders = []
687+ fusion_layers = []
686688 # Iterate from second coarsest stage (N-2) down to finest (0)
687689 for stage in range (num_stages - 2 , - 1 , - 1 ):
688690 # Patch Expanding layer (Expands from stage+1 to stage)
689- self . patch_expanders .append (
691+ patch_expanders .append (
690692 PatchExpanding (
691693 out_features = self .emb_features [stage ], # Target current stage dim
692694 dtype = self .dtype ,
@@ -696,7 +698,7 @@ def setup(self):
696698 )
697699 )
698700 # Fusion layer (Combines skip[stage] and expanded[stage+1]->[stage])
699- self . fusion_layers .append (
701+ fusion_layers .append (
700702 nn .Sequential ([ # Use Sequential for Norm + Dense
701703 nn .LayerNorm (epsilon = self .norm_epsilon , dtype = self .dtype , name = f"fusion_norm_{ stage } " ),
702704 nn .Dense (
@@ -727,8 +729,12 @@ def setup(self):
727729 for i in range (self .num_layers [stage ])
728730 ]
729731 # Append blocks in order: stage N-2, N-3, ..., 0
730- self . decoder_blocks .append (stage_blocks )
732+ decoder_blocks .append (stage_blocks )
731733
734+ self .patch_expanders = patch_expanders
735+ self .fusion_layers = fusion_layers
736+ self .decoder_blocks = decoder_blocks
737+
732738 # Note: The lists expanders, fusion_layers, decoder_blocks are now ordered
733739 # corresponding to stages N-2, N-3, ..., 0.
734740
0 commit comments