11import os
2- import tempfile
3- from glob import glob
4- from pathlib import Path
52from typing import Optional , Tuple
63
7- import mrcfile
84import torch
95import torch_em
106import torch_em .self_training as self_training
11- from elf .io import open_file
12- from sklearn .model_selection import train_test_split
137
148from .semisupervised_training import get_unsupervised_loader
15- from .supervised_training import (
16- get_2d_model , get_3d_model , get_supervised_loader , _determine_ndim , _derive_key_from_files
17- )
18- from ..inference .inference import get_model_path , compute_scale_from_voxel_size
19- from ..inference .util import _Scaler
9+ from .supervised_training import get_2d_model , get_3d_model , get_supervised_loader , _determine_ndim
2010
2111def mean_teacher_adaptation (
2212 name : str ,
@@ -38,10 +28,11 @@ def mean_teacher_adaptation(
3828 n_samples_val : Optional [int ] = None ,
3929 train_mask_paths : Optional [Tuple [str ]] = None ,
4030 val_mask_paths : Optional [Tuple [str ]] = None ,
41- sampler : Optional [callable ] = None ,
31+ patch_sampler : Optional [callable ] = None ,
32+ pseudo_label_sampler : Optional [callable ] = None ,
4233 device : int = 0 ,
4334) -> None :
44- """Run domain adapation to transfer a network trained on a source domain for a supervised
35+ """Run domain adaptation to transfer a network trained on a source domain for a supervised
4536 segmentation task to perform this task on a different target domain.
4637
4738 We support different domain adaptation settings:
@@ -84,10 +75,11 @@ def mean_teacher_adaptation(
8475 based on the patch_shape and size of the volumes used for training.
8576 n_samples_val: The number of val samples per epoch. By default this will be estimated
8677 based on the patch_shape and size of the volumes used for validation.
87- train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
88- val_mask_paths: Boundary masks used by the sampler to accept or reject patches for validation.
89- sampler: Accept or reject patches based on a condition.
90- device: GPU ID for training.
78+ train_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.
79+ val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
80+ patch_sampler: Accept or reject patches based on a condition.
81+ pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
82+ device: GPU ID for training.
9183 """
9284 assert (supervised_train_paths is None ) == (supervised_val_paths is None )
9385 is_2d , _ = _determine_ndim (patch_shape )
@@ -103,11 +95,11 @@ def mean_teacher_adaptation(
10395 model = get_3d_model (out_channels = 2 )
10496 reinit_teacher = True
10597 else :
106- print ("Mean teacehr training initialized from source model:" , source_checkpoint )
98+ print ("Mean teacher training initialized from source model:" , source_checkpoint )
10799 if os .path .isdir (source_checkpoint ):
108100 model = torch_em .util .load_model (source_checkpoint )
109101 else :
110- model = torch .load (source_checkpoint , weights_only = False )
102+ model = torch .load (source_checkpoint )
111103 reinit_teacher = False
112104
113105 optimizer = torch .optim .Adam (model .parameters (), lr = 1e-4 )
@@ -117,23 +109,24 @@ def mean_teacher_adaptation(
117109 pseudo_labeler = self_training .DefaultPseudoLabeler (confidence_threshold = confidence_threshold )
118110 loss = self_training .DefaultSelfTrainingLoss ()
119111 loss_and_metric = self_training .DefaultSelfTrainingLossAndMetric ()
120-
112+
121113 unsupervised_train_loader = get_unsupervised_loader (
122114 data_paths = unsupervised_train_paths ,
123115 raw_key = raw_key ,
124116 patch_shape = patch_shape ,
125117 batch_size = batch_size ,
126118 n_samples = n_samples_train ,
127- boundary_mask_paths = train_mask_paths ,
128- sampler = sampler
119+ sample_mask_paths = train_mask_paths ,
120+ sampler = patch_sampler
129121 )
130122 unsupervised_val_loader = get_unsupervised_loader (
131123 data_paths = unsupervised_val_paths ,
132124 raw_key = raw_key ,
133125 patch_shape = patch_shape ,
134126 batch_size = batch_size ,
135127 n_samples = n_samples_val ,
136- boundary_mask_paths = val_mask_paths , sampler = sampler
128+ sample_mask_paths = val_mask_paths ,
129+ sampler = patch_sampler
137130 )
138131
139132 if supervised_train_paths is not None :
@@ -172,142 +165,6 @@ def mean_teacher_adaptation(
172165 device = device ,
173166 reinit_teacher = reinit_teacher ,
174167 save_root = save_root ,
175- sampler = None , # TODO currently set to none cause I didn't want to pass the same sampler used by get_unsupervised_loader
168+ sampler = pseudo_label_sampler ,
176169 )
177170 trainer .fit (n_iterations )
178-
179-
180- # TODO patch shapes for other models
181- PATCH_SHAPES = {
182- "vesicles_3d" : [48 , 256 , 256 ],
183- }
184- """@private
185- """
186-
187-
188- def _get_paths (input_folder , pattern , resize_training_data , model_name , tmp_dir , val_fraction ):
189- files = sorted (glob (os .path .join (input_folder , "**" , pattern ), recursive = True ))
190- if len (files ) == 0 :
191- raise ValueError (f"Could not load any files from { input_folder } with pattern { pattern } " )
192-
193- # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
194- # And resave the volumes.
195- resave_val_crops = len (files ) < 4
196-
197- # We only resave the data if we resave val crops or resize the training data
198- resave_data = resave_val_crops or resize_training_data
199- if not resave_data :
200- train_paths , val_paths = train_test_split (files , test_size = val_fraction )
201- return train_paths , val_paths
202-
203- train_paths , val_paths = [], []
204- for file_path in files :
205- file_name = os .path .basename (file_path )
206- data = open_file (file_path , mode = "r" )["data" ][:]
207-
208- if resize_training_data :
209- with mrcfile .open (file_path ) as f :
210- voxel_size = f .voxel_size
211- voxel_size = {ax : vox_size / 10.0 for ax , vox_size in zip ("xyz" , voxel_size .item ())}
212- scale = compute_scale_from_voxel_size (voxel_size , model_name )
213- scaler = _Scaler (scale , verbose = False )
214- data = scaler .sale_input (data )
215-
216- if resave_val_crops :
217- n_slices = data .shape [0 ]
218- val_slice = int ((1.0 - val_fraction ) * n_slices )
219- train_data , val_data = data [:val_slice ], data [val_slice :]
220-
221- train_path = os .path .join (tmp_dir , Path (file_name ).with_suffix (".h5" )).replace (".h5" , "_train.h5" )
222- with open_file (train_path , mode = "w" ) as f :
223- f .create_dataset ("data" , data = train_data , compression = "lzf" )
224- train_paths .append (train_path )
225-
226- val_path = os .path .join (tmp_dir , Path (file_name ).with_suffix (".h5" )).replace (".h5" , "_val.h5" )
227- with open_file (val_path , mode = "w" ) as f :
228- f .create_dataset ("data" , data = val_data , compression = "lzf" )
229- val_paths .append (val_path )
230-
231- else :
232- output_path = os .path .join (tmp_dir , Path (file_name ).with_suffix (".h5" ))
233- with open_file (output_path , mode = "w" ) as f :
234- f .create_dataset ("data" , data = data , compression = "lzf" )
235- train_paths .append (output_path )
236-
237- if not resave_val_crops :
238- train_paths , val_paths = train_test_split (train_paths , test_size = val_fraction )
239-
240- return train_paths , val_paths
241-
242-
243- def _parse_patch_shape (patch_shape , model_name ):
244- if patch_shape is None :
245- patch_shape = PATCH_SHAPES [model_name ]
246- return patch_shape
247-
248-
249- def main ():
250- """@private
251- """
252- import argparse
253-
254- parser = argparse .ArgumentParser (
255- description = "Adapt a model to data from a different domain using unsupervised domain adaptation.\n \n "
256- "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n "
257- "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n " # noqa
258- "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument)." # noqa
259- "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
260- "Check out the information below for details on the arguments of this function." ,
261- formatter_class = argparse .RawTextHelpFormatter
262- )
263- parser .add_argument ("--name" , "-n" , required = True , help = "The name of the model to be trained. " )
264- parser .add_argument ("--input_folder" , "-i" , required = True , help = "The folder with the training data." )
265- parser .add_argument ("--file_pattern" , default = "*" ,
266- help = "The pattern for selecting files for training. For example '*.mrc' to select mrc files." )
267- parser .add_argument ("--key" , help = "The internal file path for the training data. Will be derived from the file extension by default." ) # noqa
268- parser .add_argument (
269- "--source_model" ,
270- default = "vesicles_3d" ,
271- help = "The source model used for weight initialization of teacher and student model. "
272- "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
273- )
274- parser .add_argument (
275- "--resize_training_data" , action = "store_true" ,
276- help = "Whether to resize the training data to fit the voxel size of the source model's trainign data."
277- )
278- parser .add_argument ("--n_iterations" , type = int , default = int (1e4 ), help = "The number of iterations for training." )
279- parser .add_argument (
280- "--patch_shape" , nargs = 3 , type = int ,
281- help = "The patch shape for training. By default the patch shape the source model was trained with is used."
282- )
283-
284- # More optional argument:
285- parser .add_argument ("--batch_size" , type = int , default = 1 , help = "The batch size for training." )
286- parser .add_argument ("--n_samples_train" , type = int , help = "The number of samples per epoch for training. If not given will be derived from the data size." ) # noqa
287- parser .add_argument ("--n_samples_val" , type = int , help = "The number of samples per epoch for validation. If not given will be derived from the data size." ) # noqa
288- parser .add_argument ("--val_fraction" , type = float , default = 0.15 , help = "The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed." ) # noqa
289- parser .add_argument ("--check" , action = "store_true" , help = "Visualize samples from the data loaders to ensure correct data instead of running training." ) # noqa
290-
291- args = parser .parse_args ()
292-
293- source_checkpoint = get_model_path (args .source_model )
294- patch_shape = _parse_patch_shape (args .patch_shape , args .source_model )
295- with tempfile .TemporaryDirectory () as tmp_dir :
296- unsupervised_train_paths , unsupervised_val_paths = _get_paths (
297- args .input , args .pattern , args .resize_training_data , args .source_model , tmp_dir , args .val_fraction ,
298- )
299- unsupervised_train_paths , raw_key = _derive_key_from_files (unsupervised_train_paths , args .key )
300-
301- mean_teacher_adaptation (
302- name = args .name ,
303- unsupervised_train_paths = unsupervised_train_paths ,
304- unsupervised_val_paths = unsupervised_val_paths ,
305- patch_shape = patch_shape ,
306- source_checkpoint = source_checkpoint ,
307- raw_key = raw_key ,
308- n_iterations = args .n_iterations ,
309- batch_size = args .batch_size ,
310- n_samples_train = args .n_samples_train ,
311- n_samples_val = args .n_samples_val ,
312- check = args .check ,
313- )
0 commit comments