Skip to content

Commit 396b06b

Browse files
committed
Add advprop and switch hosting providers
1 parent e22b46e commit 396b06b

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

efficientnet_pytorch/model.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,22 +206,15 @@ def from_name(cls, model_name, override_params=None):
206206
return cls(blocks_args, global_params)
207207

208208
@classmethod
209-
def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3):
209+
def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
210210
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
211-
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
211+
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
212212
if in_channels != 3:
213213
Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
214214
out_channels = round_filters(32, model._global_params)
215215
model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
216216
return model
217217

218-
@classmethod
219-
def from_pretrained(cls, model_name, num_classes=1000):
220-
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
221-
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
222-
223-
return model
224-
225218
@classmethod
226219
def get_image_size(cls, model_name):
227220
cls._check_model_name_is_valid(model_name)

efficientnet_pytorch/utils.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -295,20 +295,35 @@ def get_model_params(model_name, override_params):
295295
return blocks_args, global_params
296296

297297

298-
url_map = {
299-
'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
300-
'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
301-
'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
302-
'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
303-
'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
304-
'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
305-
'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
306-
'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
298+
url_map_aa = {
299+
'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth',
300+
'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth',
301+
'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth',
302+
'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth',
303+
'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth',
304+
'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth',
305+
'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth',
306+
'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth',
307307
}
308308

309309

310-
def load_pretrained_weights(model, model_name, load_fc=True):
310+
url_map_advprop = {
311+
'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth',
312+
'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth',
313+
'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth',
314+
'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth',
315+
'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth',
316+
'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth',
317+
'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth',
318+
'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth',
319+
'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth',
320+
}
321+
322+
323+
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
311324
""" Loads pretrained weights, and downloads if loading for the first time. """
325+
# AutoAugment or Advprop (different preprocessing)
326+
url_map = url_map_advprop if advprop else url_map_aa
312327
state_dict = model_zoo.load_url(url_map[model_name])
313328
if load_fc:
314329
model.load_state_dict(state_dict)

0 commit comments

Comments
 (0)