Skip to content

Commit 28f70f8

Browse files
committed
Added ability to load pretrained model with custom number of classes
1 parent bd3c392 commit 28f70f8

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

efficientnet_pytorch/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def from_name(cls, model_name, override_params=None):
192192

193193
@classmethod
194194
def from_pretrained(cls, model_name, num_classes=1000):
195-
model = EfficientNet.from_name(model_name, override_params={'num_classes': 1000})
195+
model = EfficientNet.from_name(model_name, override_params={'num_classes': num_classes})
196196
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
197197
return model
198198

efficientnet_pytorch/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,5 +279,11 @@ def get_model_params(model_name, override_params):
279279
def load_pretrained_weights(model, model_name, load_fc=True):
280280
""" Loads pretrained weights, and downloads if loading for the first time. """
281281
state_dict = model_zoo.load_url(url_map[model_name])
282-
model.load_state_dict(state_dict)
282+
if load_fc:
283+
model.load_state_dict(state_dict)
284+
else:
285+
state_dict.pop('_fc.weight')
286+
state_dict.pop('_fc.bias')
287+
res = model.load_state_dict(state_dict, strict=False)
288+
assert str(res.missing_keys) == str(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
283289
print('Loaded pretrained weights for {}'.format(model_name))

0 commit comments

Comments
 (0)