1+ import os
2+ from glob import glob
13from typing import Optional , Tuple
24
35import torch
46import torch_em
7+ from sklearn .model_selection import train_test_split
58from torch_em .model import AnisotropicUNet , UNet2d
69
710
@@ -95,6 +98,7 @@ def get_supervised_loader(
9598 sampler : Optional [callable ] = None ,
9699 ignore_label : Optional [int ] = None ,
97100 label_transform : Optional [callable ] = None ,
101+ label_paths : Optional [Tuple [str ]] = None ,
98102 ** loader_kwargs ,
99103) -> torch .utils .data .DataLoader :
100104 """Get a dataloader for supervised segmentation training.
@@ -118,6 +122,8 @@ def get_supervised_loader(
118122 ignored in the loss computation. By default this option is not used.
119123 label_transform: Label transform that is applied to the segmentation to compute the targets.
120124 If no label transform is passed (the default) a boundary transform is used.
125+ label_paths: Optional paths containing the labels / annotations for training.
126+ If not given, the labels are expected to be contained in the `data_paths`.
121127 loader_kwargs: Additional keyword arguments for the dataloader.
122128
123129 Returns:
@@ -155,9 +161,14 @@ def get_supervised_loader(
155161 if sampler is None :
156162 sampler = torch_em .data .sampler .MinInstanceSampler (min_num_instances = 4 )
157163
164+ if label_paths is None :
165+ label_paths = data_paths
166+ elif len (label_paths ) != len (data_paths ):
167+ raise ValueError (f"Data paths and label paths don't match: { len (data_paths )} != { len (label_paths )} " )
168+
158169 loader = torch_em .default_segmentation_loader (
159170 data_paths , raw_key ,
160- data_paths , label_key , sampler = sampler ,
171+ label_paths , label_key , sampler = sampler ,
161172 batch_size = batch_size , patch_shape = patch_shape , ndim = ndim ,
162173 is_seg_dataset = True , label_transform = label_transform , transform = transform ,
163174 num_workers = num_workers , shuffle = shuffle , n_samples = n_samples ,
@@ -177,6 +188,8 @@ def supervised_training(
177188 batch_size : int = 1 ,
178189 lr : float = 1e-4 ,
179190 n_iterations : int = int (1e5 ),
191+ train_label_paths : Optional [Tuple [str ]] = None ,
192+ val_label_paths : Optional [Tuple [str ]] = None ,
180193 train_rois : Optional [Tuple [Tuple [slice ]]] = None ,
181194 val_rois : Optional [Tuple [Tuple [slice ]]] = None ,
182195 sampler : Optional [callable ] = None ,
@@ -210,6 +223,10 @@ def supervised_training(
210223 batch_size: The batch size for training.
211224 lr: The initial learning rate.
212225 n_iterations: The number of iterations to train for.
226+ train_label_paths: Optional paths containing the label data for training.
227+ If not given, the labels are expected to be part of `train_paths`.
228+ val_label_paths: Optional paths containing the label data for validation.
229+ If not given, the labels are expected to be part of `val_paths`.
213230 train_rois: Optional region of interests for training.
214231 val_rois: Optional region of interests for validation.
215232 sampler: Optional sampler for selecting blocks for training.
@@ -231,11 +248,11 @@ def supervised_training(
231248 train_loader = get_supervised_loader (train_paths , raw_key , label_key , patch_shape , batch_size ,
232249 n_samples = n_samples_train , rois = train_rois , sampler = sampler ,
233250 ignore_label = ignore_label , label_transform = label_transform ,
234- ** loader_kwargs )
251+ label_paths = train_label_paths , ** loader_kwargs )
235252 val_loader = get_supervised_loader (val_paths , raw_key , label_key , patch_shape , batch_size ,
236253 n_samples = n_samples_val , rois = val_rois , sampler = sampler ,
237254 ignore_label = ignore_label , label_transform = label_transform ,
238- ** loader_kwargs )
255+ label_paths = val_label_paths , ** loader_kwargs )
239256
240257 if check :
241258 from torch_em .util .debug import check_loader
@@ -287,3 +304,105 @@ def supervised_training(
287304 metric = metric ,
288305 )
289306 trainer .fit (n_iterations )
307+
308+
def _parse_input_folder(folder, pattern, key):
    """Collect the files matching a pattern in a folder and determine the internal data key.

    Args:
        folder: The root folder to search. Sub-folders are searched recursively.
        pattern: The glob pattern for selecting files, for example '*.mrc'.
        key: The internal path to the data inside each file.
            If None, the key is derived from the file extension.

    Returns:
        The sorted list of file paths that match the pattern.
        The key that was passed or that was derived from the file extension.

    Raises:
        ValueError: If no file matches the pattern, if no key was passed and none can be
            derived from the file extension, or if the passed key contradicts the key
            expected for the file extension.
    """
    # 'recursive=True' is needed so that '**' matches zero or more sub-folders.
    # Without it '**' behaves like '*' and only files exactly one level down are found.
    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    if not files:
        raise ValueError(f"Could not find any files matching the pattern {pattern} in {folder}.")

    # Get the unique file extensions (general wild-cards may pick up files with multiple extensions).
    extensions = {os.path.splitext(ff)[1] for ff in files}

    # If we have more than 1 file extension we just use the key that was passed,
    # as it is unclear how to derive a consistent key.
    if len(extensions) > 1:
        return files, key

    ext = next(iter(extensions))
    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}

    if ext not in extension_to_key:
        # Unknown format: we can neither derive nor validate the key, so it must be given.
        if key is None:
            raise ValueError(
                f"You have not passed a key for the data in {folder}, "
                f"but the key could not be derived for {ext} format."
            )
        return files, key

    expected_key = extension_to_key[ext]
    # Derive the key from the extension if no key was passed.
    if key is None:
        return files, expected_key

    # If the key was passed and doesn't match the extension raise an error.
    if key != expected_key:
        raise ValueError(
            f"The expected key {expected_key} for format {ext} did not match the passed key {key}."
        )
    return files, key
336+
337+
def _parse_input_files(args):
    """Determine the training and validation file paths from the parsed command line arguments.

    If no validation folder is given, the training data is split into a training and
    a validation fraction according to `args.val_fraction`.

    Args:
        args: The parsed command line arguments. The attributes read here are
            `train_folder`, `label_folder`, `image_file_pattern`, `label_file_pattern`,
            `raw_key`, `label_key`, `val_folder`, `val_label_folder` and `val_fraction`.

    Returns:
        The image paths for training.
        The label paths for training.
        The image paths for validation.
        The label paths for validation.
        The key for the image data.
        The key for the label data.

    Raises:
        ValueError: If the number of images and labels does not match, or if only one of
            `val_folder` and `val_label_folder` was passed.
    """
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        # The fixed random state keeps the train / val split reproducible across runs.
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        # '_parse_input_folder' returns (files, key). The keys determined for the training data
        # are passed in and used for the validation data, so only the file lists are kept here.
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
        if len(val_image_paths) != len(val_label_paths):
            raise ValueError(
                f"The image and label paths parsed from {args.val_folder} and {args.val_label_folder} don't match. "
                f"The image folder contains {len(val_image_paths)}, the label folder contains {len(val_label_paths)}."
            )

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key
360+
361+
362+ # TODO enable initialization with a pre-trained model.
def main():
    """@private

    Command line entry point: parse the arguments, collect the training and validation
    files via `_parse_input_files` and run `supervised_training` with them.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning."
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # NOTE(review): patch_shape is optional here and is passed on as None if not given;
    # confirm that 'supervised_training' can handle this or make the argument required.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    # The short flag must be '-i': with '--i' argparse would use 'i' as the destination,
    # so that 'args.train_folder' would not exist.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    # 'args.check' is forwarded to 'supervised_training' below, so the flag has to be defined.
    parser.add_argument("--check", action="store_true",
                        help="Whether to check the training and validation loaders (passed as 'check' to the training function).")  # noqa
    args = parser.parse_args()

    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key = \
        _parse_input_files(args)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check,
    )
0 commit comments