diff --git a/pretrainedmodels/models/xception.py b/pretrainedmodels/models/xception.py index 7783c477..c362a557 100644 --- a/pretrainedmodels/models/xception.py +++ b/pretrainedmodels/models/xception.py @@ -217,12 +217,12 @@ def xception(num_classes=1000, pretrained='imagenet'): model = Xception(num_classes=num_classes) if pretrained: settings = pretrained_settings['xception'][pretrained] - assert num_classes == settings['num_classes'], \ - "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) - - model = Xception(num_classes=num_classes) + model = Xception(num_classes=settings['num_classes']) model.load_state_dict(model_zoo.load_url(settings['url'])) - + in_features = model.fc.in_features + del model.fc + model.fc = nn.Linear(in_features, num_classes) + model.input_space = settings['input_space'] model.input_size = settings['input_size'] model.input_range = settings['input_range']