Skip to content

Commit 3676f4e

Browse files
committed
Added support for non-RGB images
1 parent b6a1be9 commit 3676f4e

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

efficientnet_pytorch/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,21 @@ def from_name(cls, model_name, override_params=None):
190190
blocks_args, global_params = get_model_params(model_name, override_params)
191191
return cls(blocks_args, global_params)
192192

193+
@classmethod
194+
def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3):
195+
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
196+
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
197+
if in_channels != 3:
198+
Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
199+
out_channels = round_filters(32, model._global_params)
200+
model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
201+
return model
202+
193203
@classmethod
194204
def from_pretrained(cls, model_name, num_classes=1000):
195205
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
196206
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
207+
197208
return model
198209

199210
@classmethod

0 commit comments

Comments
 (0)