Skip to content

Commit 9a0d45f

Browse files
committed
Added B4 and B5 models
1 parent 6d36a35 commit 9a0d45f

File tree

5 files changed

+22
-2
lines changed

5 files changed

+22
-2
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ venv.bak/
112112
tensorflow/
113113
example/test*
114114
*.pth*
115+
examples/imagenet/data/
116+
!examples/imagenet/data/README.md
115117

116118

117119

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
11
# EfficientNet PyTorch
2+
3+
### Update (June 18, 2019)
4+
5+
The B4 and B5 models are now available. Their usage is identical to the other models:
6+
```python
7+
from efficientnet_pytorch import EfficientNet
8+
model = EfficientNet.from_pretrained('efficientnet-b4')
9+
```
10+
Upgrade the pip package with `pip install --upgrade pytorch_efficientnet`.
11+
12+
### Overview
213
This repository contains an op-for-op PyTorch reimplementation of [EfficientNet](https://arxiv.org/abs/1905.11946), along with pre-trained models and examples.
314

415
The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.

efficientnet_pytorch/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ def get_model_params(model_name, override_params):
236236
'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet-b1-dbc7070a.pth',
237237
'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet-b2-27687264.pth',
238238
'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet-b3-c8376fa2.pth',
239+
'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet-b4-e116e8b3.pth',
240+
'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet-b5-586e6cc6.pth',
239241
}
240242

241243
def load_pretrained_weights(model, model_name):

examples/imagenet/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,11 @@ def main_worker(gpu, ngpus_per_node, args):
132132
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
133133
world_size=args.world_size, rank=args.rank)
134134
# create model
135-
if 'efficientnet' in args.arch: # NEW
135+
if args.arch == 'efficientnet-b5':
136+
model = EfficientNet.from_name(args.arch)
137+
model.load_state_dict(torch.load('../../tf_to_pytorch/pretrained_pytorch/efficientnet-b5.pth'))
138+
print("Using pretrained b5")
139+
elif 'efficientnet' in args.arch: # NEW
136140
if args.pretrained:
137141
model = EfficientNet.from_pretrained(args.arch)
138142
print("=> using pre-trained model '{}'".format(args.arch))

tf_to_pytorch/convert_tf_to_pt/original_tf/efficientnet_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
from six.moves import xrange # pylint: disable=redefined-builtin
3131
import tensorflow as tf
3232

33-
from efficientnet_pytorch import utils
33+
#from efficientnet_pytorch import utils
34+
from original_tf import utils
3435

3536
GlobalParams = collections.namedtuple('GlobalParams', [
3637
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',

0 commit comments

Comments
 (0)