Skip to content

Commit ee826d4

Browse files
committed
Merge branch 'dropout_update' into relu_update
2 parents 31f50ed + d4aed28 commit ee826d4

File tree

2 files changed

+113
-5
lines changed

2 files changed

+113
-5
lines changed

efficientnet_pytorch/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ def __init__(self, blocks_args=None, global_params=None):
150150
self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
151151

152152
# Final linear layer
153-
self._dropout = self._global_params.dropout_rate
153+
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
154+
self._dropout = nn.Dropout(self._global_params.dropout_rate)
154155
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
155156

156157
def extract_features(self, inputs):
@@ -173,14 +174,14 @@ def extract_features(self, inputs):
173174

174175
def forward(self, inputs):
175176
""" Calls extract_features to extract features, applies final linear layer, and returns logits. """
176-
177+
bs = inputs.size(0)
177178
# Convolution layers
178179
x = self.extract_features(inputs)
179180

180181
# Pooling and final linear layer
181-
x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
182-
if self._dropout:
183-
x = F.dropout(x, p=self._dropout, training=self.training)
182+
x = self._avg_pooling(x)
183+
x = x.view(bs, -1)
184+
x = self._dropout(x)
184185
x = self._fc(x)
185186
return x
186187

tests/test_model.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from collections import OrderedDict
2+
3+
import pytest
4+
import torch
5+
import torch.nn as nn
6+
7+
from efficientnet_pytorch import EfficientNet
8+
9+
10+
# -- fixtures -------------------------------------------------------------------------------------
11+
12+
@pytest.fixture(scope='module', params=[x for x in range(4)])
13+
def model(request):
14+
return 'efficientnet-b{}'.format(request.param)
15+
16+
17+
@pytest.fixture(scope='module', params=[True, False])
18+
def pretrained(request):
19+
return request.param
20+
21+
22+
@pytest.fixture(scope='function')
23+
def net(model, pretrained):
24+
return EfficientNet.from_pretrained(model) if pretrained else EfficientNet.from_name(model)
25+
26+
27+
# -- tests ----------------------------------------------------------------------------------------
28+
29+
@pytest.mark.parametrize('img_size', [224, 256, 512])
30+
def test_forward(net, img_size):
31+
"""Test `.forward()` doesn't throw an error"""
32+
data = torch.zeros((1, 3, img_size, img_size))
33+
output = net(data)
34+
assert not torch.isnan(output).any()
35+
36+
37+
def test_dropout_training(net):
38+
"""Test dropout `.training` is set by `.train()` on parent `nn.module`"""
39+
net.train()
40+
assert net._dropout.training == True
41+
42+
43+
def test_dropout_eval(net):
44+
"""Test dropout `.training` is set by `.eval()` on parent `nn.module`"""
45+
net.eval()
46+
assert net._dropout.training == False
47+
48+
49+
def test_dropout_update(net):
50+
"""Test dropout `.training` is updated by `.train()` and `.eval()` on parent `nn.module`"""
51+
net.train()
52+
assert net._dropout.training == True
53+
net.eval()
54+
assert net._dropout.training == False
55+
net.train()
56+
assert net._dropout.training == True
57+
net.eval()
58+
assert net._dropout.training == False
59+
60+
61+
@pytest.mark.parametrize('img_size', [224, 256, 512])
62+
def test_modify_dropout(net, img_size):
63+
"""Test ability to modify dropout and fc modules of network"""
64+
dropout = nn.Sequential(OrderedDict([
65+
('_bn2', nn.BatchNorm1d(net._bn1.num_features)),
66+
('_drop1', nn.Dropout(p=net._global_params.dropout_rate)),
67+
('_linear1', nn.Linear(net._bn1.num_features, 512)),
68+
('_relu', nn.ReLU()),
69+
('_bn3', nn.BatchNorm1d(512)),
70+
('_drop2', nn.Dropout(p=net._global_params.dropout_rate / 2))
71+
]))
72+
fc = nn.Linear(512, net._global_params.num_classes)
73+
74+
net._dropout = dropout
75+
net._fc = fc
76+
77+
data = torch.zeros((2, 3, img_size, img_size))
78+
output = net(data)
79+
assert not torch.isnan(output).any()
80+
81+
82+
@pytest.mark.parametrize('img_size', [224, 256, 512])
83+
def test_modify_pool(net, img_size):
84+
"""Test ability to modify pooling module of network"""
85+
86+
class AdaptiveMaxAvgPool(nn.Module):
87+
88+
def __init__(self):
89+
super().__init__()
90+
self.ada_avgpool = nn.AdaptiveAvgPool2d(1)
91+
self.ada_maxpool = nn.AdaptiveMaxPool2d(1)
92+
93+
def forward(self, x):
94+
avg_x = self.ada_avgpool(x)
95+
max_x = self.ada_maxpool(x)
96+
x = torch.cat((avg_x, max_x), dim=1)
97+
return x
98+
99+
avg_pooling = AdaptiveMaxAvgPool()
100+
fc = nn.Linear(net._fc.in_features * 2, net._global_params.num_classes)
101+
102+
net._avg_pooling = avg_pooling
103+
net._fc = fc
104+
105+
data = torch.zeros((2, 3, img_size, img_size))
106+
output = net(data)
107+
assert not torch.isnan(output).any()

0 commit comments

Comments
 (0)