@@ -115,6 +115,24 @@ def test_encoder_mixin_properties_and_set_in_channels() -> None:
115115 assert encoder .conv .in_channels == 5
116116
117117
118+ def test_set_in_channels_noop_for_default () -> None :
119+ """Calling with DEFAULT_IN_CHANNELS should skip patching."""
120+ encoder = DummyEncoder ()
121+ encoder .set_in_channels (DEFAULT_IN_CHANNELS , pretrained = True )
122+ assert encoder ._in_channels == DEFAULT_IN_CHANNELS
123+
124+
125+ def test_set_in_channels_modify_out_channels () -> None :
126+ """First output channels should change when in_channels is modified."""
127+ encoder = DummyEncoder ()
128+ encoder ._out_channels [0 ] = DEFAULT_IN_CHANNELS
129+
130+ encoder .set_in_channels (5 , pretrained = False )
131+
132+ assert encoder ._out_channels [0 ] == 5
133+ assert encoder ._in_channels == 5
134+
135+
118136def test_encoder_mixin_make_dilated_and_validation () -> None :
119137 """make_dilated should error on invalid stride and patch convs otherwise."""
120138 encoder = DummyEncoder ()
@@ -132,6 +150,38 @@ def test_encoder_mixin_make_dilated_and_validation() -> None:
132150 assert conv32 .dilation == (4 , 4 )
133151
134152
153+ def test_make_dilated_skips_stages_below_output_stride () -> None :
154+ """Stages at or below the target stride should be left untouched."""
155+ encoder = DummyEncoder ()
156+ encoder .conv .stride = (2 , 2 ) # stage_stride == 16, so should be skipped
157+ encoder .conv .dilation = (1 , 1 )
158+
159+ encoder .make_dilated (output_stride = 16 )
160+
161+ # stage at stride 16 skipped
162+ assert encoder .conv .stride == (2 , 2 )
163+ assert encoder .conv .dilation == (1 , 1 )
164+
165+ # stage at stride 32 modified
166+ conv32 = encoder .get_stages ()[32 ][0 ]
167+ assert conv32 .dilation == (2 , 2 )
168+ assert conv32 .padding == (2 , 2 )
169+
170+
171+ def test_efficientnet_encoder_get_stages_splits_blocks () -> None :
172+ """Test get_stages for dilation modification."""
173+ encoder = EfficientNetEncoder (
174+ stage_idxs = [1 , 2 , 4 ],
175+ out_channels = [3 , 8 , 16 , 32 , 64 , 128 ],
176+ depth = 3 ,
177+ channel_multiplier = 1.0 ,
178+ depth_multiplier = 1.0 ,
179+ )
180+ stages = encoder .get_stages ()
181+ assert len (stages ) == 2
182+ assert stages .keys () == {16 , 32 }
183+
184+
135185def test_get_efficientnet_kwargs_shapes_and_values () -> None :
136186 """get_efficientnet_kwargs should produce expected keys and scaling."""
137187 # confirm output contains decoded blocks and scaled channels
@@ -160,18 +210,14 @@ def test_efficientnet_encoder_depth_validation_and_forward() -> None:
160210 stage_idxs = [2 , 3 , 5 ],
161211 out_channels = [3 , 32 , 24 , 40 , 112 , 320 ],
162212 depth = 3 ,
163- channel_multiplier = 0.5 ,
164- depth_multiplier = 0.5 ,
213+ channel_multiplier = 1.0 ,
214+ depth_multiplier = 1.0 ,
165215 )
166216 x = torch .randn (1 , 3 , 32 , 32 )
167217 features = encoder (x )
168218 assert len (features ) == encoder ._depth + 1
169219 assert torch .equal (features [0 ], x )
170-
171- # ensure classifier keys are dropped before loading into the model
172- extended_state = dict (encoder .state_dict ())
173- extended_state ["classifier.bias" ] = torch .tensor ([1.0 ])
174- extended_state ["classifier.weight" ] = torch .tensor ([[1.0 ]])
175- load_result = encoder .load_state_dict (extended_state , strict = True )
176- assert not load_result .missing_keys
177- assert not load_result .unexpected_keys
220+ # cover depth-gated forward branches up to depth 3
221+ assert features [1 ].shape [1 ] == 32
222+ assert features [2 ].shape [1 ] == 24
223+ assert features [3 ].shape [1 ] == 40
0 commit comments