66import torch_em .self_training as self_training
77from torchvision import transforms
88
9-
10- def get_3d_model (out_channels ):
11- raise NotImplementedError
12-
13-
14- def get_supervised_loader ():
15- raise NotImplementedError
9+ from .util import get_supervised_loader , get_3d_model
1610
1711
1812def weak_augmentations (p : float = 0.75 ) -> callable :
@@ -79,15 +73,17 @@ def get_unsupervised_loader(
7973 return loader
8074
8175
82- def mean_teacher_adaptation (
76+ def mean_teacher_training (
8377 name : str ,
8478 unsupervised_train_paths : Tuple [str ],
8579 unsupervised_val_paths : Tuple [str ],
8680 patch_shape : Tuple [int , int , int ],
8781 save_root : Optional [str ] = None ,
8882 source_checkpoint : Optional [str ] = None ,
89- supervised_train_paths : Optional [Tuple [str ]] = None ,
90- supervised_val_paths : Optional [Tuple [str ]] = None ,
83+ supervised_train_image_paths : Optional [Tuple [str ]] = None ,
84+ supervised_val_image_paths : Optional [Tuple [str ]] = None ,
85+ supervised_train_label_paths : Optional [Tuple [str ]] = None ,
86+ supervised_val_label_paths : Optional [Tuple [str ]] = None ,
9187 confidence_threshold : float = 0.9 ,
9288 raw_key : Optional [str ] = None ,
9389 raw_key_supervised : Optional [str ] = None ,
@@ -99,14 +95,13 @@ def mean_teacher_adaptation(
9995 n_samples_val : Optional [int ] = None ,
10096 sampler : Optional [callable ] = None ,
10197) -> None :
102- """Run domain adapation to transfer a network trained on a source domain for a supervised
103- segmentation task to perform this task on a different target domain.
98+ """This function implements network training with a mean teacher approach.
10499
105- We support different domain adaptation settings:
106- - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
107- 'supervised_val_paths' are not given .
108- - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
109- when 'supervised_train_paths' and 'supervised_val_paths' are given .
100+ It can be used for semi-supervised learning, unsupervised domain adaptation and supervised domain adaptation.
101+ These different training modes can be used as this:
102+ - semi-supervised learning: pass 'unsupervised_train/val_paths' and 'supervised_train/val_paths' .
103+ - unsupervised domain adaptation: pass 'unsupervised_train/val_paths' and 'source_checkpoint'.
104+ - supervised domain adaptation: pass 'unsupervised_train/val_paths', 'supervised_train/val_paths', 'source_checkpoint' .
110105
111106 Args:
112107 name: The name for the checkpoint to be trained.
@@ -125,30 +120,38 @@ def mean_teacher_adaptation(
125120 If the checkpoint is not given, then both student and teacher model are initialized
126121 from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
127122 be given in order to provide training data from the source domain.
128- supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
129- This training data is optional. If given, it is used for unsupervised learnig and requires labels.
130- supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
131- This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
123+ supervised_train_image_paths: Paths to the files for the supervised image data; training split.
124+ This training data is optional. If given, it also requires labels.
125+ supervised_val_image_paths: Ppaths to the files for the supervised image data; validation split.
126+ This validation data is optional. If given, it also requires labels.
127+ supervised_train_label_paths: Filepaths to the files for the supervised label masks; training split.
128+ This training data is optional.
129+ supervised_val_label_paths: Filepaths to the files for the supervised label masks; validation split.
130+ This tvalidation data is optional.
132131 confidence_threshold: The threshold for filtering data in the unsupervised loss.
133132 The label filtering is done based on the uncertainty of network predictions, and only
134133 the data with higher certainty than this threshold is used for training.
135- raw_key: The key that holds the raw data inside of the hdf5 or similar files.
134+ raw_key: The key that holds the raw data inside of the hdf5 or similar files;
135+ for the unsupervised training data. Set to None for tifs.
136+ raw_key_supervised: The key that holds the raw data inside of the hdf5 or similar files;
137+ for the supervised training data. Set to None for tifs.
136138 label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
137- This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
139+ This is only required if `supervised_train_label_paths` and `supervised_val_label_paths` are given.
140+ Set to None for tifs.
138141 batch_size: The batch size for training.
139142 lr: The initial learning rate.
140143 n_iterations: The number of iterations to train for.
141144 n_samples_train: The number of train samples per epoch. By default this will be estimated
142145 based on the patch_shape and size of the volumes used for training.
143146 n_samples_val: The number of val samples per epoch. By default this will be estimated
144147 based on the patch_shape and size of the volumes used for validation.
145- """
146- assert (supervised_train_paths is None ) == (supervised_val_paths is None )
148+ """ # noqa
149+ assert (supervised_train_image_paths is None ) == (supervised_val_image_paths is None )
147150
148151 if source_checkpoint is None :
149- # training from scratch only makes sense if we have supervised training data
152+ # Training from scratch only makes sense if we have supervised training data
150153 # that's why we have the assertion here.
151- assert supervised_train_paths is not None
154+ assert supervised_train_image_paths is not None
152155 model = get_3d_model (out_channels = 3 )
153156 reinit_teacher = True
154157 else :
@@ -174,15 +177,16 @@ def mean_teacher_adaptation(
174177 unsupervised_val_paths , raw_key , patch_shape , batch_size , n_samples = n_samples_val
175178 )
176179
177- if supervised_train_paths is not None :
178- assert label_key is not None
180+ if supervised_train_image_paths is not None :
179181 supervised_train_loader = get_supervised_loader (
180- supervised_train_paths , raw_key_supervised , label_key ,
181- patch_shape , batch_size , n_samples = n_samples_train ,
182+ supervised_train_image_paths , supervised_train_label_paths ,
183+ patch_shape = patch_shape , batch_size = batch_size , n_samples = n_samples_train ,
184+ image_key = raw_key_supervised , label_key = label_key ,
182185 )
183186 supervised_val_loader = get_supervised_loader (
184- supervised_val_paths , raw_key_supervised , label_key ,
185- patch_shape , batch_size , n_samples = n_samples_val ,
187+ supervised_val_image_paths , supervised_val_label_paths ,
188+ patch_shape = patch_shape , batch_size = batch_size , n_samples = n_samples_val ,
189+ image_key = raw_key_supervised , label_key = label_key ,
186190 )
187191 else :
188192 supervised_train_loader = None
0 commit comments