11import os
2- import tempfile
3- from glob import glob
4- from pathlib import Path
52from typing import Optional , Tuple
63
7- import mrcfile
84import torch
95import torch_em
106import torch_em .self_training as self_training
11- from elf .io import open_file
12- from sklearn .model_selection import train_test_split
137
148from .semisupervised_training import get_unsupervised_loader
15- from .supervised_training import (
16- get_2d_model , get_3d_model , get_supervised_loader , _determine_ndim , _derive_key_from_files
17- )
18- from ..inference .inference import get_model_path , compute_scale_from_voxel_size
19- from ..inference .util import _Scaler
9+ from .supervised_training import get_2d_model , get_3d_model , get_supervised_loader , _determine_ndim
10+
11+ class NewPseudoLabeler (self_training .DefaultPseudoLabeler ):
12+ """Compute pseudo labels based on model predictions, typically from a teacher model.
13+ By default, assumes that the first channel contains the transformed data and the second channel contains the background mask. # TODO update description
14+
15+ Args:
16+ activation: Activation function applied to the teacher prediction.
17+ confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
18+ If None is given no mask will be computed.
19+ threshold_from_both_sides: Whether to include both values bigger than the threshold
20+ and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
21+ The former should be used for binary labels, the latter for for multiclass labels.
22+ confidence_mask_channel: A specific channel to use for computing the confidence mask.
23+ By default the confidence mask is computed across all channels independently.
24+ This is useful, if only one of the channels encodes a probability.
25+ raw_channel: # TODO add description
26+ background_mask_channel: # TODO add description
27+ """
28+ def __init__ (
29+ self ,
30+ activation : Optional [torch .nn .Module ] = None ,
31+ confidence_threshold : Optional [float ] = None ,
32+ threshold_from_both_sides : bool = True ,
33+ confidence_mask_channel : Optional [int ] = None ,
34+ raw_channel : Optional [int ] = 0 ,
35+ background_mask_channel : Optional [int ] = 1 ,
36+ ):
37+ super ().__init__ (activation , confidence_threshold , threshold_from_both_sides )
38+ self .raw_channel = raw_channel
39+ self .background_mask_channel = background_mask_channel
40+ self .confidence_mask_channel = confidence_mask_channel
41+
42+ def _subtract_background (self , pseudo_labels : torch .Tensor , background_mask : torch .Tensor ):
43+ bool_mask = background_mask .bool ()
44+ return pseudo_labels .masked_fill (bool_mask , 0 )
45+
46+ def __call__ (self , teacher : torch .nn .Module , input_ : torch .Tensor ) -> torch .Tensor :
47+ """Compute pseudo-labels.
48+
49+ Args:
50+ teacher: The teacher model.
51+ input_: The input for this batch.
52+
53+ Returns:
54+ The pseudo-labels.
55+ """
56+ if self .background_mask_channel is not None :
57+ if input_ .ndim != 5 :
58+ raise ValueError (f"Expect data with 5 dimensions (B, C, D, H, W), got shape { input_ .shape } ." )
59+
60+ if self .background_mask_channel > input_ .shape [1 ]:
61+ raise ValueError (f"Channel index { self .background_mask_channel } is out of bounds for shape { input_ .shape } ." )
62+
63+ background_mask = input_ [:, self .background_mask_channel ].unsqueeze (1 )
64+ input_ = input_ [:, self .raw_channel ].unsqueeze (1 )
65+
66+ pseudo_labels = teacher (input_ )
67+
68+ if self .activation is not None :
69+ pseudo_labels = self .activation (pseudo_labels )
70+ if self .confidence_threshold is None :
71+ label_mask = None
72+ else :
73+ mask_input = pseudo_labels if self .confidence_mask_channel is None \
74+ else pseudo_labels [self .confidence_mask_channel :(self .confidence_mask_channel + 1 )]
75+ label_mask = self ._compute_label_mask_both_sides (mask_input ) if self .threshold_from_both_sides \
76+ else self ._compute_label_mask_one_side (mask_input )
77+ if self .confidence_mask_channel is not None :
78+ size = (pseudo_labels .shape [0 ], pseudo_labels .shape [1 ], * ([- 1 ] * (pseudo_labels .ndim - 2 )))
79+ label_mask = label_mask .expand (* size )
80+
81+ if self .background_mask_channel is not None :
82+ pseudo_labels = self ._subtract_background (pseudo_labels , background_mask )
83+
84+ return pseudo_labels , label_mask
85+
2086
2187def mean_teacher_adaptation (
2288 name : str ,
@@ -36,13 +102,14 @@ def mean_teacher_adaptation(
36102 n_iterations : int = int (1e4 ),
37103 n_samples_train : Optional [int ] = None ,
38104 n_samples_val : Optional [int ] = None ,
39- train_mask_paths : Optional [Tuple [str ]] = None ,
40- val_mask_paths : Optional [Tuple [str ]] = None ,
105+ train_sample_mask_paths : Optional [Tuple [str ]] = None ,
106+ val_sample_mask_paths : Optional [Tuple [str ]] = None ,
107+ train_background_mask_paths : Optional [Tuple [str ]] = None ,
41108 patch_sampler : Optional [callable ] = None ,
42109 pseudo_label_sampler : Optional [callable ] = None ,
43110 device : int = 0 ,
44111) -> None :
45- """Run domain adaptation to transfer a network trained on a source domain for a supervised
112+ """Run domain adapation to transfer a network trained on a source domain for a supervised
46113 segmentation task to perform this task on a different target domain.
47114
48115 We support different domain adaptation settings:
@@ -85,10 +152,11 @@ def mean_teacher_adaptation(
85152 based on the patch_shape and size of the volumes used for training.
86153 n_samples_val: The number of val samples per epoch. By default this will be estimated
87154 based on the patch_shape and size of the volumes used for validation.
88- train_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.
89- val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
90- patch_sampler: Accept or reject patches based on a condition.
91- pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
155+ train_sample_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.
156+ val_sample_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
157+ train_background_mask_paths: # TODO add description
158+ patch_sampler: A sampler for rejecting patches based on a defined conditon.
159+ pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition.
92160 device: GPU ID for training.
93161 """
94162 assert (supervised_train_paths is None ) == (supervised_val_paths is None )
@@ -109,24 +177,29 @@ def mean_teacher_adaptation(
109177 if os .path .isdir (source_checkpoint ):
110178 model = torch_em .util .load_model (source_checkpoint )
111179 else :
112- model = torch .load (source_checkpoint , weights_only = False )
180+ model = torch .load (source_checkpoint )
113181 reinit_teacher = False
114182
115183 optimizer = torch .optim .Adam (model .parameters (), lr = 1e-4 )
116184 scheduler = torch .optim .lr_scheduler .ReduceLROnPlateau (optimizer , mode = "min" , factor = 0.5 , patience = 5 )
117185
118186 # self training functionality
119- pseudo_labeler = self_training .DefaultPseudoLabeler (confidence_threshold = confidence_threshold )
187+ if train_background_mask_paths is not None :
188+ pseudo_labeler = NewPseudoLabeler (confidence_threshold = confidence_threshold , background_mask_channel = 1 )
189+ else :
190+ pseudo_labeler = self_training .DefaultPseudoLabeler (confidence_threshold = confidence_threshold )
191+
120192 loss = self_training .DefaultSelfTrainingLoss ()
121193 loss_and_metric = self_training .DefaultSelfTrainingLossAndMetric ()
122-
194+
123195 unsupervised_train_loader = get_unsupervised_loader (
124196 data_paths = unsupervised_train_paths ,
125197 raw_key = raw_key ,
126198 patch_shape = patch_shape ,
127199 batch_size = batch_size ,
128200 n_samples = n_samples_train ,
129- sample_mask_paths = train_mask_paths ,
201+ sample_mask_paths = train_sample_mask_paths ,
202+ background_mask_paths = train_background_mask_paths ,
130203 sampler = patch_sampler
131204 )
132205 unsupervised_val_loader = get_unsupervised_loader (
@@ -135,7 +208,8 @@ def mean_teacher_adaptation(
135208 patch_shape = patch_shape ,
136209 batch_size = batch_size ,
137210 n_samples = n_samples_val ,
138- sample_mask_paths = val_mask_paths ,
211+ sample_mask_paths = val_sample_mask_paths ,
212+ background_mask_paths = None ,
139213 sampler = patch_sampler
140214 )
141215
@@ -178,138 +252,3 @@ def mean_teacher_adaptation(
178252 sampler = pseudo_label_sampler ,
179253 )
180254 trainer .fit (n_iterations )
181-
182-
183- # TODO patch shapes for other models
184- PATCH_SHAPES = {
185- "vesicles_3d" : [48 , 256 , 256 ],
186- }
187- """@private
188- """
189-
190-
191- def _get_paths (input_folder , pattern , resize_training_data , model_name , tmp_dir , val_fraction ):
192- files = sorted (glob (os .path .join (input_folder , "**" , pattern ), recursive = True ))
193- if len (files ) == 0 :
194- raise ValueError (f"Could not load any files from { input_folder } with pattern { pattern } " )
195-
196- # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
197- # And resave the volumes.
198- resave_val_crops = len (files ) < 4
199-
200- # We only resave the data if we resave val crops or resize the training data
201- resave_data = resave_val_crops or resize_training_data
202- if not resave_data :
203- train_paths , val_paths = train_test_split (files , test_size = val_fraction )
204- return train_paths , val_paths
205-
206- train_paths , val_paths = [], []
207- for file_path in files :
208- file_name = os .path .basename (file_path )
209- data = open_file (file_path , mode = "r" )["data" ][:]
210-
211- if resize_training_data :
212- with mrcfile .open (file_path ) as f :
213- voxel_size = f .voxel_size
214- voxel_size = {ax : vox_size / 10.0 for ax , vox_size in zip ("xyz" , voxel_size .item ())}
215- scale = compute_scale_from_voxel_size (voxel_size , model_name )
216- scaler = _Scaler (scale , verbose = False )
217- data = scaler .sale_input (data )
218-
219- if resave_val_crops :
220- n_slices = data .shape [0 ]
221- val_slice = int ((1.0 - val_fraction ) * n_slices )
222- train_data , val_data = data [:val_slice ], data [val_slice :]
223-
224- train_path = os .path .join (tmp_dir , Path (file_name ).with_suffix (".h5" )).replace (".h5" , "_train.h5" )
225- with open_file (train_path , mode = "w" ) as f :
226- f .create_dataset ("data" , data = train_data , compression = "lzf" )
227- train_paths .append (train_path )
228-
229- val_path = os .path .join (tmp_dir , Path (file_name ).with_suffix (".h5" )).replace (".h5" , "_val.h5" )
230- with open_file (val_path , mode = "w" ) as f :
231- f .create_dataset ("data" , data = val_data , compression = "lzf" )
232- val_paths .append (val_path )
233-
234- else :
235- output_path = os .path .join (tmp_dir , Path (file_name ).with_suffix (".h5" ))
236- with open_file (output_path , mode = "w" ) as f :
237- f .create_dataset ("data" , data = data , compression = "lzf" )
238- train_paths .append (output_path )
239-
240- if not resave_val_crops :
241- train_paths , val_paths = train_test_split (train_paths , test_size = val_fraction )
242-
243- return train_paths , val_paths
244-
245-
246- def _parse_patch_shape (patch_shape , model_name ):
247- if patch_shape is None :
248- patch_shape = PATCH_SHAPES [model_name ]
249- return patch_shape
250-
251- def main ():
252- """@private
253- """
254- import argparse
255-
256- parser = argparse .ArgumentParser (
257- description = "Adapt a model to data from a different domain using unsupervised domain adaptation.\n \n "
258- "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n "
259- "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n " # noqa
260- "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument)." # noqa
261- "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
262- "Check out the information below for details on the arguments of this function." ,
263- formatter_class = argparse .RawTextHelpFormatter
264- )
265- parser .add_argument ("--name" , "-n" , required = True , help = "The name of the model to be trained. " )
266- parser .add_argument ("--input_folder" , "-i" , required = True , help = "The folder with the training data." )
267- parser .add_argument ("--file_pattern" , default = "*" ,
268- help = "The pattern for selecting files for training. For example '*.mrc' to select mrc files." )
269- parser .add_argument ("--key" , help = "The internal file path for the training data. Will be derived from the file extension by default." ) # noqa
270- parser .add_argument (
271- "--source_model" ,
272- default = "vesicles_3d" ,
273- help = "The source model used for weight initialization of teacher and student model. "
274- "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
275- )
276- parser .add_argument (
277- "--resize_training_data" , action = "store_true" ,
278- help = "Whether to resize the training data to fit the voxel size of the source model's trainign data."
279- )
280- parser .add_argument ("--n_iterations" , type = int , default = int (1e4 ), help = "The number of iterations for training." )
281- parser .add_argument (
282- "--patch_shape" , nargs = 3 , type = int ,
283- help = "The patch shape for training. By default the patch shape the source model was trained with is used."
284- )
285-
286- # More optional argument:
287- parser .add_argument ("--batch_size" , type = int , default = 1 , help = "The batch size for training." )
288- parser .add_argument ("--n_samples_train" , type = int , help = "The number of samples per epoch for training. If not given will be derived from the data size." ) # noqa
289- parser .add_argument ("--n_samples_val" , type = int , help = "The number of samples per epoch for validation. If not given will be derived from the data size." ) # noqa
290- parser .add_argument ("--val_fraction" , type = float , default = 0.15 , help = "The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed." ) # noqa
291- parser .add_argument ("--check" , action = "store_true" , help = "Visualize samples from the data loaders to ensure correct data instead of running training." ) # noqa
292-
293- args = parser .parse_args ()
294-
295- source_checkpoint = get_model_path (args .source_model )
296- patch_shape = _parse_patch_shape (args .patch_shape , args .source_model )
297- with tempfile .TemporaryDirectory () as tmp_dir :
298- unsupervised_train_paths , unsupervised_val_paths = _get_paths (
299- args .input , args .pattern , args .resize_training_data , args .source_model , tmp_dir , args .val_fraction ,
300- )
301- unsupervised_train_paths , raw_key = _derive_key_from_files (unsupervised_train_paths , args .key )
302-
303- mean_teacher_adaptation (
304- name = args .name ,
305- unsupervised_train_paths = unsupervised_train_paths ,
306- unsupervised_val_paths = unsupervised_val_paths ,
307- patch_shape = patch_shape ,
308- source_checkpoint = source_checkpoint ,
309- raw_key = raw_key ,
310- n_iterations = args .n_iterations ,
311- batch_size = args .batch_size ,
312- n_samples_train = args .n_samples_train ,
313- n_samples_val = args .n_samples_val ,
314- check = args .check ,
315- )
0 commit comments