55from deepprofiler .learning .model import DeepProfilerModel
66from deepprofiler .learning .tf2train import DeepProfilerModelV2
77from deepprofiler .imaging .augmentations import AugmentationLayer
8+ from deepprofiler .imaging .augmentations import AugmentationLayerV2
89
910
1011##################################################
@@ -19,12 +20,16 @@ def model_factory(config, dset, crop_generator, val_crop_generator, is_training)
1920 if inspect .currentframe ().f_back .f_code .co_name == 'learn_model_v2' :
2021 tf .compat .v1 .enable_v2_behavior ()
2122 tf .config .run_functions_eagerly (True )
22- return createModelClass (DeepProfilerModelV2 , config , dset , crop_generator , val_crop_generator , is_training )
23+ augmentation_base = AugmentationLayerV2 ()
24+ return createModelClass (DeepProfilerModelV2 , config , dset , crop_generator ,
25+ val_crop_generator , is_training , augmentation_base )
2326 else :
24- return createModelClass (DeepProfilerModel , config , dset , crop_generator , val_crop_generator , is_training )
27+ augmentation_base = AugmentationLayer ()
28+ return createModelClass (DeepProfilerModel , config , dset , crop_generator ,
29+ val_crop_generator , is_training , augmentation_base )
2530
2631
27- def createModelClass (base , config , dset , crop_generator , val_crop_generator , is_training ):
32+ def createModelClass (base , config , dset , crop_generator , val_crop_generator , is_training , augmentation_base ):
2833 class ModelClass (base ):
2934 def __init__ (self , config , dset , crop_generator , val_crop_generator , is_training ):
3035 super (ModelClass , self ).__init__ (config , dset , crop_generator , val_crop_generator , is_training )
@@ -105,7 +110,7 @@ def define_model(self, config, dset):
105110 if self .config ["train" ]["model" ].get ("augmentations" ) is True :
106111 model = tf .compat .v1 .keras .models .model_from_json (
107112 model .to_json (),
108- {'AugmentationLayer' : AugmentationLayer }
113+ {'AugmentationLayer' : augmentation_base }
109114 )
110115 else :
111116 model = tf .compat .v1 .keras .models .model_from_json (model .to_json ())
@@ -117,7 +122,7 @@ def define_model(self, config, dset):
117122 ## Support for ImageNet initialization
118123 def copy_pretrained_weights (self ):
119124 base_model = self .get_model (self .config , weights = "imagenet" )
120- lshift = self .feature_model .layers [1 ].name == 'augmentation_layer ' # Shift one layer to accommodate the AugmentationLayer
125+ lshift = self .feature_model .layers [1 ].name == 'augmentation_layer_1 ' # Shift one layer to accommodate the AugmentationLayer
121126
122127 # => Transfer all weights except conv1.1
123128 total_layers = len (base_model .layers )
0 commit comments