Skip to content

Commit f938df2

Browse files
committed
Added test for modifying pooling in final layers of network
1 parent 4f1cf2f commit f938df2

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

efficientnet_pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ def extract_features(self, inputs):
174174

175175
def forward(self, inputs):
176176
""" Calls extract_features to extract features, applies final linear layer, and returns logits. """
177-
177+
bs = inputs.size(0)
178178
# Convolution layers
179179
x = self.extract_features(inputs)
180180

181181
# Pooling and final linear layer
182182
x = self._avg_pooling(x)
183-
x = x.view(x.size(0), -1)
183+
x = x.view(bs, -1)
184184
x = self._dropout(x)
185185
x = self._fc(x)
186186
return x

tests/test_model.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def pretrained(request):
1919
return request.param
2020

2121

22-
@pytest.fixture(scope='module')
22+
@pytest.fixture(scope='function')
2323
def net(model, pretrained):
2424
return EfficientNet.from_pretrained(model) if pretrained else EfficientNet.from_name(model)
2525

@@ -59,8 +59,8 @@ def test_dropout_update(net):
5959

6060

6161
@pytest.mark.parametrize('img_size', [224, 256, 512])
62-
def test_modify_head(net, img_size):
63-
"""Test ability to modify final layers of network"""
62+
def test_modify_dropout(net, img_size):
63+
"""Test ability to modify dropout and fc modules of network"""
6464
dropout = nn.Sequential(OrderedDict([
6565
('_bn2', nn.BatchNorm1d(net._bn1.num_features)),
6666
('_drop1', nn.Dropout(p=net._global_params.dropout_rate)),
@@ -77,3 +77,32 @@ def test_modify_head(net, img_size):
7777
data = torch.zeros((2, 3, img_size, img_size))
7878
output = net(data)
7979
assert not torch.isnan(output).any()
80+
81+
82+
@pytest.mark.parametrize('img_size', [224, 256, 512])
83+
def test_modify_norm(net, img_size):
84+
"""Test ability to modify norm layer of network"""
85+
86+
class AdaptiveMaxAvgPool(nn.Module):
87+
88+
def __init__(self):
89+
super().__init__()
90+
self.ada_avgpool = nn.AdaptiveAvgPool2d(1)
91+
self.ada_maxpool = nn.AdaptiveMaxPool2d(1)
92+
93+
def forward(self, x):
94+
avg_x = self.ada_avgpool(x)
95+
max_x = self.ada_maxpool(x)
96+
x = torch.cat((avg_x, max_x), dim=1)
97+
x = x.view(x.size(0), -1)
98+
return x
99+
100+
avg_pooling = AdaptiveMaxAvgPool()
101+
fc = nn.Linear(net._fc.in_features * 2, net._global_params.num_classes)
102+
103+
net._avg_pooling = avg_pooling
104+
net._fc = fc
105+
106+
data = torch.zeros((2, 3, img_size, img_size))
107+
output = net(data)
108+
assert not torch.isnan(output).any()

0 commit comments

Comments
 (0)