Skip to content

Commit 9d6a1a0

Browse files
added segmentation_models for further adjustments
1 parent 6db2708 commit 9d6a1a0

File tree

21 files changed

+3173
-0
lines changed

21 files changed

+3173
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
The MIT License
2+
3+
Copyright (c) 2018, Pavel Yakubovskiy
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in
13+
all copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
THE SOFTWARE.
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import os
2+
import functools
3+
from .__version__ import __version__
4+
from . import base
5+
6+
_KERAS_FRAMEWORK_NAME = 'keras'
7+
_TF_KERAS_FRAMEWORK_NAME = 'tf.keras'
8+
9+
_DEFAULT_KERAS_FRAMEWORK = _KERAS_FRAMEWORK_NAME
10+
_KERAS_FRAMEWORK = None
11+
_KERAS_BACKEND = None
12+
_KERAS_LAYERS = None
13+
_KERAS_MODELS = None
14+
_KERAS_UTILS = None
15+
_KERAS_LOSSES = None
16+
17+
18+
def inject_global_losses(func):
19+
@functools.wraps(func)
20+
def wrapper(*args, **kwargs):
21+
kwargs['losses'] = _KERAS_LOSSES
22+
return func(*args, **kwargs)
23+
24+
return wrapper
25+
26+
27+
def inject_global_submodules(func):
28+
@functools.wraps(func)
29+
def wrapper(*args, **kwargs):
30+
kwargs['backend'] = _KERAS_BACKEND
31+
kwargs['layers'] = _KERAS_LAYERS
32+
kwargs['models'] = _KERAS_MODELS
33+
kwargs['utils'] = _KERAS_UTILS
34+
return func(*args, **kwargs)
35+
36+
return wrapper
37+
38+
39+
def filter_kwargs(func):
40+
@functools.wraps(func)
41+
def wrapper(*args, **kwargs):
42+
new_kwargs = {k: v for k, v in kwargs.items() if k in ['backend', 'layers', 'models', 'utils']}
43+
return func(*args, **new_kwargs)
44+
45+
return wrapper
46+
47+
48+
def framework():
49+
"""Return name of Segmentation Models framework"""
50+
return _KERAS_FRAMEWORK
51+
52+
53+
def set_framework(name):
54+
"""Set framework for Segmentation Models
55+
56+
Args:
57+
name (str): one of ``keras``, ``tf.keras``, case insensitive.
58+
59+
Raises:
60+
ValueError: in case of incorrect framework name.
61+
ImportError: in case framework is not installed.
62+
63+
"""
64+
name = name.lower()
65+
66+
if name == _KERAS_FRAMEWORK_NAME:
67+
import keras
68+
import efficientnet.keras # init custom objects
69+
elif name == _TF_KERAS_FRAMEWORK_NAME:
70+
from tensorflow import keras
71+
import efficientnet.tfkeras # init custom objects
72+
else:
73+
raise ValueError('Not correct module name `{}`, use `{}` or `{}`'.format(
74+
name, _KERAS_FRAMEWORK_NAME, _TF_KERAS_FRAMEWORK_NAME))
75+
76+
global _KERAS_BACKEND, _KERAS_LAYERS, _KERAS_MODELS
77+
global _KERAS_UTILS, _KERAS_LOSSES, _KERAS_FRAMEWORK
78+
79+
_KERAS_FRAMEWORK = name
80+
_KERAS_BACKEND = keras.backend
81+
_KERAS_LAYERS = keras.layers
82+
_KERAS_MODELS = keras.models
83+
_KERAS_UTILS = keras.utils
84+
_KERAS_LOSSES = keras.losses
85+
86+
# allow losses/metrics get keras submodules
87+
base.KerasObject.set_submodules(
88+
backend=keras.backend,
89+
layers=keras.layers,
90+
models=keras.models,
91+
utils=keras.utils,
92+
)
93+
94+
95+
# set default framework
96+
_framework = os.environ.get('SM_FRAMEWORK', _DEFAULT_KERAS_FRAMEWORK)
97+
try:
98+
set_framework(_framework)
99+
except ImportError:
100+
other = _TF_KERAS_FRAMEWORK_NAME if _framework == _KERAS_FRAMEWORK_NAME else _KERAS_FRAMEWORK_NAME
101+
set_framework(other)
102+
103+
print('Segmentation Models: using `{}` framework.'.format(_KERAS_FRAMEWORK))
104+
105+
# import helper modules
106+
from . import losses
107+
from . import metrics
108+
from . import utils
109+
110+
# wrap segmentation models with framework modules
111+
from .backbones.backbones_factory import Backbones
112+
from .models.unet import Unet as _Unet
113+
from .models.pspnet import PSPNet as _PSPNet
114+
from .models.linknet import Linknet as _Linknet
115+
from .models.fpn import FPN as _FPN
116+
117+
Unet = inject_global_submodules(_Unet)
118+
PSPNet = inject_global_submodules(_PSPNet)
119+
Linknet = inject_global_submodules(_Linknet)
120+
FPN = inject_global_submodules(_FPN)
121+
get_available_backbone_names = Backbones.models_names
122+
123+
124+
def get_preprocessing(name):
125+
preprocess_input = Backbones.get_preprocessing(name)
126+
# add bakcend, models, layers, utils submodules in kwargs
127+
preprocess_input = inject_global_submodules(preprocess_input)
128+
# delete other kwargs
129+
# keras-applications preprocessing raise an error if something
130+
# except `backend`, `layers`, `models`, `utils` passed in kwargs
131+
preprocess_input = filter_kwargs(preprocess_input)
132+
return preprocess_input
133+
134+
135+
__all__ = [
136+
'Unet', 'PSPNet', 'FPN', 'Linknet',
137+
'set_framework', 'framework',
138+
'get_preprocessing', 'get_available_backbone_names',
139+
'losses', 'metrics', 'utils',
140+
'__version__',
141+
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
VERSION = (1, 0, 1)
2+
3+
__version__ = '.'.join(map(str, VERSION))

orthophoto-segmentation-benchmark-toolkit/model_backends/segmentation_models/backbones/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import copy
2+
import efficientnet.model as eff
3+
from classification_models.models_factory import ModelsFactory
4+
5+
from . import inception_resnet_v2 as irv2
6+
from . import inception_v3 as iv3
7+
8+
9+
class BackbonesFactory(ModelsFactory):
10+
_default_feature_layers = {
11+
12+
# List of layers to take features from backbone in the following order:
13+
# (x16, x8, x4, x2, x1) - `x4` mean that features has 4 times less spatial
14+
# resolution (Height x Width) than input image.
15+
16+
# VGG
17+
'vgg16': ('block5_conv3', 'block4_conv3', 'block3_conv3', 'block2_conv2', 'block1_conv2'),
18+
'vgg19': ('block5_conv4', 'block4_conv4', 'block3_conv4', 'block2_conv2', 'block1_conv2'),
19+
20+
# ResNets
21+
'resnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
22+
'resnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
23+
'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
24+
'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
25+
'resnet152': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
26+
27+
# ResNeXt
28+
'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
29+
'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
30+
31+
# Inception
32+
'inceptionv3': (228, 86, 16, 9),
33+
'inceptionresnetv2': (594, 260, 16, 9),
34+
35+
# DenseNet
36+
'densenet121': (311, 139, 51, 4),
37+
'densenet169': (367, 139, 51, 4),
38+
'densenet201': (479, 139, 51, 4),
39+
40+
# SE models
41+
'seresnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
42+
'seresnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
43+
'seresnet50': (246, 136, 62, 4),
44+
'seresnet101': (552, 136, 62, 4),
45+
'seresnet152': (858, 208, 62, 4),
46+
'seresnext50': (1078, 584, 254, 4),
47+
'seresnext101': (2472, 584, 254, 4),
48+
'senet154': (6884, 1625, 454, 12),
49+
50+
# Mobile Nets
51+
'mobilenet': ('conv_pw_11_relu', 'conv_pw_5_relu', 'conv_pw_3_relu', 'conv_pw_1_relu'),
52+
'mobilenetv2': ('block_13_expand_relu', 'block_6_expand_relu', 'block_3_expand_relu',
53+
'block_1_expand_relu'),
54+
55+
# EfficientNets
56+
'efficientnetb0': ('block6a_expand_activation', 'block4a_expand_activation',
57+
'block3a_expand_activation', 'block2a_expand_activation'),
58+
'efficientnetb1': ('block6a_expand_activation', 'block4a_expand_activation',
59+
'block3a_expand_activation', 'block2a_expand_activation'),
60+
'efficientnetb2': ('block6a_expand_activation', 'block4a_expand_activation',
61+
'block3a_expand_activation', 'block2a_expand_activation'),
62+
'efficientnetb3': ('block6a_expand_activation', 'block4a_expand_activation',
63+
'block3a_expand_activation', 'block2a_expand_activation'),
64+
'efficientnetb4': ('block6a_expand_activation', 'block4a_expand_activation',
65+
'block3a_expand_activation', 'block2a_expand_activation'),
66+
'efficientnetb5': ('block6a_expand_activation', 'block4a_expand_activation',
67+
'block3a_expand_activation', 'block2a_expand_activation'),
68+
'efficientnetb6': ('block6a_expand_activation', 'block4a_expand_activation',
69+
'block3a_expand_activation', 'block2a_expand_activation'),
70+
'efficientnetb7': ('block6a_expand_activation', 'block4a_expand_activation',
71+
'block3a_expand_activation', 'block2a_expand_activation'),
72+
73+
}
74+
75+
_models_update = {
76+
'inceptionresnetv2': [irv2.InceptionResNetV2, irv2.preprocess_input],
77+
'inceptionv3': [iv3.InceptionV3, iv3.preprocess_input],
78+
79+
'efficientnetb0': [eff.EfficientNetB0, eff.preprocess_input],
80+
'efficientnetb1': [eff.EfficientNetB1, eff.preprocess_input],
81+
'efficientnetb2': [eff.EfficientNetB2, eff.preprocess_input],
82+
'efficientnetb3': [eff.EfficientNetB3, eff.preprocess_input],
83+
'efficientnetb4': [eff.EfficientNetB4, eff.preprocess_input],
84+
'efficientnetb5': [eff.EfficientNetB5, eff.preprocess_input],
85+
'efficientnetb6': [eff.EfficientNetB6, eff.preprocess_input],
86+
'efficientnetb7': [eff.EfficientNetB7, eff.preprocess_input],
87+
}
88+
89+
# currently not supported
90+
_models_delete = ['resnet50v2', 'resnet101v2', 'resnet152v2',
91+
'nasnetlarge', 'nasnetmobile', 'xception']
92+
93+
@property
94+
def models(self):
95+
all_models = copy.copy(self._models)
96+
all_models.update(self._models_update)
97+
for k in self._models_delete:
98+
del all_models[k]
99+
return all_models
100+
101+
def get_backbone(self, name, *args, **kwargs):
102+
model_fn, _ = self.get(name)
103+
model = model_fn(*args, **kwargs)
104+
return model
105+
106+
def get_feature_layers(self, name, n=5):
107+
return self._default_feature_layers[name][:n]
108+
109+
def get_preprocessing(self, name):
110+
return self.get(name)[1]
111+
112+
113+
Backbones = BackbonesFactory()

0 commit comments

Comments
 (0)