
Commit d1ce4a0 (parent 9c62b72)

improve test coverage


tests/models/test_arch_timm_efficientnet.py

1 file changed: 56 additions & 10 deletions
@@ -115,6 +115,24 @@ def test_encoder_mixin_properties_and_set_in_channels() -> None:
     assert encoder.conv.in_channels == 5


+def test_set_in_channels_noop_for_default() -> None:
+    """Calling with DEFAULT_IN_CHANNELS should skip patching."""
+    encoder = DummyEncoder()
+    encoder.set_in_channels(DEFAULT_IN_CHANNELS, pretrained=True)
+    assert encoder._in_channels == DEFAULT_IN_CHANNELS
+
+
+def test_set_in_channels_modify_out_channels() -> None:
+    """First output channels should change when in_channels is modified."""
+    encoder = DummyEncoder()
+    encoder._out_channels[0] = DEFAULT_IN_CHANNELS
+
+    encoder.set_in_channels(5, pretrained=False)
+
+    assert encoder._out_channels[0] == 5
+    assert encoder._in_channels == 5
+
+
 def test_encoder_mixin_make_dilated_and_validation() -> None:
     """make_dilated should error on invalid stride and patch convs otherwise."""
     encoder = DummyEncoder()
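
For context, the two set_in_channels tests added above assume mixin behaviour roughly like the following. This is a minimal sketch for illustration only, not the repository's DummyEncoder fixture or its real EncoderMixin; the DEFAULT_IN_CHANNELS constant and the _in_channels/_out_channels attributes are taken from the test code, everything else is an assumption.

import torch.nn as nn

DEFAULT_IN_CHANNELS = 3  # assumed value of the constant imported by the tests


class SketchEncoder(nn.Module):
    """Hypothetical encoder showing the behaviour the two new tests check."""

    def __init__(self) -> None:
        super().__init__()
        self._in_channels = DEFAULT_IN_CHANNELS
        self._out_channels = [DEFAULT_IN_CHANNELS, 16, 32]
        self.conv = nn.Conv2d(DEFAULT_IN_CHANNELS, 16, kernel_size=3)

    def set_in_channels(self, in_channels: int, pretrained: bool = True) -> None:
        # No-op when the default channel count is requested (first test above).
        if in_channels == DEFAULT_IN_CHANNELS:
            return
        self._in_channels = in_channels
        # Keep the reported stem output channels in sync (second test above).
        if self._out_channels[0] == DEFAULT_IN_CHANNELS:
            self._out_channels[0] = in_channels
        # A real mixin would also patch self.conv here, e.g. rebuilding it with
        # the new in_channels and, when pretrained=True, reusing the old weights.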
@@ -132,6 +150,38 @@ def test_encoder_mixin_make_dilated_and_validation() -> None:
     assert conv32.dilation == (4, 4)


+def test_make_dilated_skips_stages_below_output_stride() -> None:
+    """Stages at or below the target stride should be left untouched."""
+    encoder = DummyEncoder()
+    encoder.conv.stride = (2, 2)  # stage_stride == 16, so should be skipped
+    encoder.conv.dilation = (1, 1)
+
+    encoder.make_dilated(output_stride=16)
+
+    # stage at stride 16 skipped
+    assert encoder.conv.stride == (2, 2)
+    assert encoder.conv.dilation == (1, 1)
+
+    # stage at stride 32 modified
+    conv32 = encoder.get_stages()[32][0]
+    assert conv32.dilation == (2, 2)
+    assert conv32.padding == (2, 2)
+
+
+def test_efficientnet_encoder_get_stages_splits_blocks() -> None:
+    """Test get_stages for dilation modification."""
+    encoder = EfficientNetEncoder(
+        stage_idxs=[1, 2, 4],
+        out_channels=[3, 8, 16, 32, 64, 128],
+        depth=3,
+        channel_multiplier=1.0,
+        depth_multiplier=1.0,
+    )
+    stages = encoder.get_stages()
+    assert len(stages) == 2
+    assert stages.keys() == {16, 32}
+
+
 def test_get_efficientnet_kwargs_shapes_and_values() -> None:
     """get_efficientnet_kwargs should produce expected keys and scaling."""
     # confirm output contains decoded blocks and scaled channels
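
The make_dilated assertions above follow the usual stride-for-dilation trade: to cap the encoder at output_stride=16, the stage that would reach stride 32 keeps its resolution by dropping its stride and dilating its kernels instead, while the stride-16 stage is left alone. A rough sketch of that patch on a single torch Conv2d, assuming a 3x3 kernel and a dilation rate of 2 to match the (2, 2) values asserted in the test:

import torch.nn as nn

conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)

rate = 2
conv.stride = (1, 1)          # stop downsampling in this stage
conv.dilation = (rate, rate)  # dilate the kernel to keep the receptive field
conv.padding = (rate, rate)   # for a 3x3 kernel, padding == dilation preserves size

assert conv.dilation == (2, 2) and conv.padding == (2, 2)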
@@ -160,18 +210,14 @@ def test_efficientnet_encoder_depth_validation_and_forward() -> None:
         stage_idxs=[2, 3, 5],
         out_channels=[3, 32, 24, 40, 112, 320],
         depth=3,
-        channel_multiplier=0.5,
-        depth_multiplier=0.5,
+        channel_multiplier=1.0,
+        depth_multiplier=1.0,
     )
     x = torch.randn(1, 3, 32, 32)
     features = encoder(x)
     assert len(features) == encoder._depth + 1
     assert torch.equal(features[0], x)
-
-    # ensure classifier keys are dropped before loading into the model
-    extended_state = dict(encoder.state_dict())
-    extended_state["classifier.bias"] = torch.tensor([1.0])
-    extended_state["classifier.weight"] = torch.tensor([[1.0]])
-    load_result = encoder.load_state_dict(extended_state, strict=True)
-    assert not load_result.missing_keys
-    assert not load_result.unexpected_keys
+    # cover depth-gated forward branches up to depth 3
+    assert features[1].shape[1] == 32
+    assert features[2].shape[1] == 24
+    assert features[3].shape[1] == 40
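
The reworked forward test now pins the per-depth channel widths instead of the removed state_dict round-trip. Read against the out_channels list passed to the encoder, the expectation appears to be positional: features[0] echoes the input, and features[i] for i >= 1 carries out_channels[i] channels. A small illustration of that indexing, using only values copied from the diff (not repository code):

out_channels = [3, 32, 24, 40, 112, 320]
depth = 3

# features[0] is the input itself; stages 1..depth contribute one feature map each.
expected_widths = {i: out_channels[i] for i in range(1, depth + 1)}
assert expected_widths == {1: 32, 2: 24, 3: 40}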
