1818from ..inference .inference import get_model_path , compute_scale_from_voxel_size
1919from ..inference .util import _Scaler
2020
21-
2221def mean_teacher_adaptation (
2322 name : str ,
2423 unsupervised_train_paths : Tuple [str ],
@@ -37,9 +36,13 @@ def mean_teacher_adaptation(
3736 n_iterations : int = int (1e4 ),
3837 n_samples_train : Optional [int ] = None ,
3938 n_samples_val : Optional [int ] = None ,
40- sampler : Optional [callable ] = None ,
39+ train_mask_paths : Optional [Tuple [str ]] = None ,
40+ val_mask_paths : Optional [Tuple [str ]] = None ,
41+ patch_sampler : Optional [callable ] = None ,
42+ pseudo_label_sampler : Optional [callable ] = None ,
43+ device : int = 0 ,
4144) -> None :
42- """Run domain adapation to transfer a network trained on a source domain for a supervised
45+ """Run domain adaptation to transfer a network trained on a source domain for a supervised
4346 segmentation task to perform this task on a different target domain.
4447
4548 We support different domain adaptation settings:
@@ -82,6 +85,11 @@ def mean_teacher_adaptation(
8285 based on the patch_shape and size of the volumes used for training.
8386 n_samples_val: The number of val samples per epoch. By default this will be estimated
8487 based on the patch_shape and size of the volumes used for validation.
88+ train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
89+ val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
90+ patch_sampler: Accept or reject patches based on a condition.
91+ pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
92+ device: GPU ID for training.
8593 """
8694 assert (supervised_train_paths is None ) == (supervised_val_paths is None )
8795 is_2d , _ = _determine_ndim (patch_shape )
@@ -97,7 +105,7 @@ def mean_teacher_adaptation(
97105 model = get_3d_model (out_channels = 2 )
98106 reinit_teacher = True
99107 else :
100- print ("Mean teacehr training initialized from source model:" , source_checkpoint )
108+ print ("Mean teacher training initialized from source model:" , source_checkpoint )
101109 if os .path .isdir (source_checkpoint ):
102110 model = torch_em .util .load_model (source_checkpoint )
103111 else :
@@ -111,12 +119,24 @@ def mean_teacher_adaptation(
111119 pseudo_labeler = self_training .DefaultPseudoLabeler (confidence_threshold = confidence_threshold )
112120 loss = self_training .DefaultSelfTrainingLoss ()
113121 loss_and_metric = self_training .DefaultSelfTrainingLossAndMetric ()
114-
122+
115123 unsupervised_train_loader = get_unsupervised_loader (
116- unsupervised_train_paths , raw_key , patch_shape , batch_size , n_samples = n_samples_train
124+ data_paths = unsupervised_train_paths ,
125+ raw_key = raw_key ,
126+ patch_shape = patch_shape ,
127+ batch_size = batch_size ,
128+ n_samples = n_samples_train ,
129+ sample_mask_paths = train_mask_paths ,
130+ sampler = patch_sampler
117131 )
118132 unsupervised_val_loader = get_unsupervised_loader (
119- unsupervised_val_paths , raw_key , patch_shape , batch_size , n_samples = n_samples_val
133+ data_paths = unsupervised_val_paths ,
134+ raw_key = raw_key ,
135+ patch_shape = patch_shape ,
136+ batch_size = batch_size ,
137+ n_samples = n_samples_val ,
138+ sample_mask_paths = val_mask_paths ,
139+ sampler = patch_sampler
120140 )
121141
122142 if supervised_train_paths is not None :
@@ -133,7 +153,7 @@ def mean_teacher_adaptation(
133153 supervised_train_loader = None
134154 supervised_val_loader = None
135155
136- device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
156+ device = torch .device (f "cuda: { device } " ) if torch .cuda .is_available () else torch .device ("cpu" )
137157 trainer = self_training .MeanTeacherTrainer (
138158 name = name ,
139159 model = model ,
@@ -155,11 +175,11 @@ def mean_teacher_adaptation(
155175 device = device ,
156176 reinit_teacher = reinit_teacher ,
157177 save_root = save_root ,
158- sampler = sampler ,
178+ sampler = pseudo_label_sampler ,
159179 )
160180 trainer .fit (n_iterations )
161-
162-
181+
182+
163183# TODO patch shapes for other models
164184PATCH_SHAPES = {
165185 "vesicles_3d" : [48 , 256 , 256 ],
@@ -228,7 +248,6 @@ def _parse_patch_shape(patch_shape, model_name):
228248 patch_shape = PATCH_SHAPES [model_name ]
229249 return patch_shape
230250
231-
232251def main ():
233252 """@private
234253 """
@@ -293,4 +312,4 @@ def main():
293312 n_samples_train = args .n_samples_train ,
294313 n_samples_val = args .n_samples_val ,
295314 check = args .check ,
296- )
315+ )
0 commit comments