-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathefficient_net_model.py
More file actions
executable file
·29 lines (22 loc) · 1 KB
/
efficient_net_model.py
File metadata and controls
executable file
·29 lines (22 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from efficientnet_pytorch import EfficientNet
import torch
class EfficientNetModel(torch.nn.Module):
    """EfficientNet-b0 backbone wrapped with an input adapter and an output head.

    Used for training/evaluation. The forward pass is:

    1. a 3x3 convolution (stride 1, padding 1, no bias) mapping
       ``in_channels`` to ``out_channels`` (e.g. 6 input channels);
    2. the EfficientNet-b0 backbone;
    3. a linear layer mapping the backbone's 1000 outputs to ``out_dim``
       values (e.g. ``out_dim=1`` for a single class).

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the adapter convolution
            and fed into the backbone. NOTE(review): the pretrained
            EfficientNet-b0 stem presumably expects 3 channels, so this most
            likely has to be 3 -- confirm before using another value.
        out_dim: Size of the last dimension of the model output.
        state_dict: Optional state dict for the EfficientNet backbone
            (``self._model``). When given, it is loaded with the default
            strict ``load_state_dict``, and the published pretrained weights
            are not fetched.
    """

    def __init__(self, in_channels, out_channels, out_dim, state_dict=None):
        super().__init__()
        if state_dict is None:
            # No weights supplied: start from the published pretrained weights.
            self._model = EfficientNet.from_pretrained('efficientnet-b0')
        else:
            # A strict load_state_dict replaces every parameter and buffer, so
            # fetching the pretrained weights first (which may need network
            # access) would be wasted work. Build the bare architecture instead.
            self._model = EfficientNet.from_name('efficientnet-b0')
            self._model.load_state_dict(state_dict)

        # Custom layers. 1000 is the output width of the backbone's default
        # classification head, which this layer consumes.
        self.out_layer = torch.nn.Linear(1000, out_dim)
        self.down_channel_layer = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        # Each layer is reachable both through its own attribute and through
        # self.seq. Keep both registrations (and their names) unchanged so
        # previously saved checkpoints keep loading.
        self.seq = torch.nn.Sequential(
            self.down_channel_layer,
            self._model,
            self.out_layer,
        )

    def forward(self, x):
        """Run the adapter conv, the backbone and the output head on ``x``.

        Args:
            x: Image batch with ``in_channels`` channels; presumably shaped
                (N, in_channels, H, W) -- confirm against the data loader.

        Returns:
            The output of the final linear layer, whose last dimension is
            ``out_dim``.
        """
        return self.seq(x)