Skip to content

Commit 1ca2cb5

Browse files
committed
Augmentation layer choice between trainings.
1 parent 106d188 commit 1ca2cb5

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

plugins/models/efficientnet.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import efficientnet.tfkeras as efn
55

66
from deepprofiler.imaging.augmentations import AugmentationLayer
7+
from deepprofiler.imaging.augmentations import AugmentationLayerV2
78
from deepprofiler.learning.model import DeepProfilerModel
89
from deepprofiler.learning.tf2train import DeepProfilerModelV2
910

@@ -12,12 +13,16 @@ def model_factory(config, dset, crop_generator, val_crop_generator, is_training)
1213
if inspect.currentframe().f_back.f_code.co_name == 'learn_model_v2':
1314
tf.compat.v1.enable_v2_behavior()
1415
tf.config.run_functions_eagerly(True)
15-
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, val_crop_generator, is_training)
16+
augmentation_base = AugmentationLayerV2()
17+
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator,
18+
val_crop_generator, is_training, augmentation_base)
1619
else:
17-
return createModelClass(DeepProfilerModel, config, dset, crop_generator, val_crop_generator, is_training)
20+
augmentation_base = AugmentationLayer()
21+
return createModelClass(DeepProfilerModel, config, dset, crop_generator,
22+
val_crop_generator, is_training, augmentation_base)
1823

1924

20-
def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training):
25+
def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training, augmentation_base):
2126
class ModelClass(base):
2227
def __init__(self, config, dset, crop_generator, val_crop_generator, is_training):
2328
super(ModelClass, self).__init__(config, dset, crop_generator, val_crop_generator, is_training)
@@ -99,7 +104,7 @@ def define_model(self, config, dset):
99104
if self.config["train"]["model"].get("augmentations") is True:
100105
model = tf.compat.v1.keras.models.model_from_json(
101106
model.to_json(),
102-
{'AugmentationLayer': AugmentationLayer}
107+
{'AugmentationLayer': augmentation_base}
103108
)
104109
else:
105110
model = tf.compat.v1.keras.models.model_from_json(model.to_json())
@@ -108,7 +113,7 @@ def define_model(self, config, dset):
108113

109114
def copy_pretrained_weights(self):
110115
base_model = self.get_model(self.config, weights="imagenet")
111-
lshift = self.feature_model.layers[1].name == 'augmentation_layer' # Shift one layer to accommodate the AugmentationLayer
116+
lshift = self.feature_model.layers[1].name == 'augmentation_layer_1' # Shift one layer to accommodate the AugmentationLayer
112117

113118
# => Transfer all weights except conv1.1
114119
total_layers = len(base_model.layers)

plugins/models/resnet.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from deepprofiler.learning.model import DeepProfilerModel
66
from deepprofiler.learning.tf2train import DeepProfilerModelV2
77
from deepprofiler.imaging.augmentations import AugmentationLayer
8+
from deepprofiler.imaging.augmentations import AugmentationLayerV2
89

910

1011
##################################################
@@ -19,12 +20,16 @@ def model_factory(config, dset, crop_generator, val_crop_generator, is_training)
1920
if inspect.currentframe().f_back.f_code.co_name == 'learn_model_v2':
2021
tf.compat.v1.enable_v2_behavior()
2122
tf.config.run_functions_eagerly(True)
22-
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, val_crop_generator, is_training)
23+
augmentation_base = AugmentationLayerV2()
24+
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator,
25+
val_crop_generator, is_training, augmentation_base)
2326
else:
24-
return createModelClass(DeepProfilerModel, config, dset, crop_generator, val_crop_generator, is_training)
27+
augmentation_base = AugmentationLayer()
28+
return createModelClass(DeepProfilerModel, config, dset, crop_generator,
29+
val_crop_generator, is_training, augmentation_base)
2530

2631

27-
def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training):
32+
def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training, augmentation_base):
2833
class ModelClass(base):
2934
def __init__(self, config, dset, crop_generator, val_crop_generator, is_training):
3035
super(ModelClass, self).__init__(config, dset, crop_generator, val_crop_generator, is_training)
@@ -105,7 +110,7 @@ def define_model(self, config, dset):
105110
if self.config["train"]["model"].get("augmentations") is True:
106111
model = tf.compat.v1.keras.models.model_from_json(
107112
model.to_json(),
108-
{'AugmentationLayer': AugmentationLayer}
113+
{'AugmentationLayer': augmentation_base}
109114
)
110115
else:
111116
model = tf.compat.v1.keras.models.model_from_json(model.to_json())
@@ -117,7 +122,7 @@ def define_model(self, config, dset):
117122
## Support for ImageNet initialization
118123
def copy_pretrained_weights(self):
119124
base_model = self.get_model(self.config, weights="imagenet")
120-
lshift = self.feature_model.layers[1].name == 'augmentation_layer' # Shift one layer to accommodate the AugmentationLayer
125+
lshift = self.feature_model.layers[1].name == 'augmentation_layer_1' # Shift one layer to accommodate the AugmentationLayer
121126

122127
# => Transfer all weights except conv1.1
123128
total_layers = len(base_model.layers)

0 commit comments

Comments
 (0)