diff --git a/libs/models_keras.py b/libs/models_keras.py index 3545bcd..1556486 100644 --- a/libs/models_keras.py +++ b/libs/models_keras.py @@ -40,8 +40,7 @@ def build_unet(size=300, basef=64, maxf=512, encoder='resnet50', pretrained=True def make_encoder(input, name='resnet50', pretrained=True): if name == 'resnet18': - from classification_models.keras import Classifiers - ResNet18, _ = Classifiers.get('resnet18') + from keras.applications.resnet import ResNet18 model = ResNet18( weights='imagenet' if pretrained else None, input_tensor=input,