Skip to content

Commit cb83b4f

Browse files
authored
Misleading error when trying to load models with underscore
Thanks for this repo! Quick remark: I tried passing `efficientnet_b0` but you get `KeyError: 'efficientnet_b0'` if you try to pass this. Yet, if you pass a wrong model name such as `efficientnet_0` you get the following message. ``` ValueError: model_name should be one of: efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7 ``` The error should tell you to provide a name with a dash instead. ``` ValueError: model_name should be one of: efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 ``` I think the proposed small change fixes the issue.
1 parent e5c8726 commit cb83b4f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

efficientnet_pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,6 @@ def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=Fal
207207
""" Validates model name. None that pretrained weights are only available for
208208
the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
209209
num_models = 4 if also_need_pretrained_weights else 8
210-
valid_models = ['efficientnet_b'+str(i) for i in range(num_models)]
211-
if model_name.replace('-','_') not in valid_models:
210+
valid_models = ['efficientnet-b'+str(i) for i in range(num_models)]
211+
if model_name not in valid_models:
212212
raise ValueError('model_name should be one of: ' + ', '.join(valid_models))

0 commit comments

Comments
 (0)