1+ import argparse
12import os
3+ from datetime import datetime
24from glob import glob
35
46import torch_em
5-
67from torch_em .model import UNet3d
78
# Root folder with the annotated training crops on the GWDG cluster.
# Override via the --root / -i command line argument when training elsewhere.
ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
10+
11+
def get_image_and_label_paths(root):
    """Collect matching image and label paths below a root folder.

    Recursively finds all tif files under ``root``, excluding annotation and
    cellpose mask outputs, and pairs every image with its corresponding
    "<name>_annotations.tif" label file located in the same folder.
    Images without a matching annotation file are reported and skipped.

    Args:
        root: Folder that is searched recursively for tif images.

    Returns:
        Tuple of two equal-length lists: the image paths and label paths.
    """
    exclude_names = ["annotations", "cp_masks"]
    all_image_paths = sorted(glob(os.path.join(root, "**/**.tif"), recursive=True))
    # Filter out label and mask files so that only raw images remain.
    all_image_paths = [
        path for path in all_image_paths
        if not any(exclude in path for exclude in exclude_names)
    ]

    image_paths, label_paths = [], []
    label_extensions = ["_annotations.tif"]
    for path in all_image_paths:
        folder, fname = os.path.split(path)
        fname = os.path.splitext(fname)[0]

        # Look for a label file next to the image, for any supported extension.
        # NOTE: the filename must be built without extra spaces; the diff
        # rendering had mangled this into f"{ fname } { ext } ", which would
        # never match an existing file.
        label_path = None
        for ext in label_extensions:
            candidate_label_path = os.path.join(folder, f"{fname}{ext}")
            if os.path.exists(candidate_label_path):
                label_path = candidate_label_path
                break

        if label_path is None:
            print("Did not find annotations for", path)
            print("This image will not be used for training.")
        else:
            image_paths.append(path)
            label_paths.append(label_path)

    assert len(image_paths) == len(label_paths)
    return image_paths, label_paths
1040
1141
def select_paths(image_paths, label_paths, split, filter_empty):
    """Select the train or val subset from matched image and label paths.

    The first 85% of the files are used for training and the remainder for
    validation. For any other value of ``split`` all (filtered) paths are
    returned unchanged.

    Args:
        image_paths: The sorted image paths.
        label_paths: The label paths matching ``image_paths``.
        split: The split to select, either "train" or "val".
        filter_empty: Whether to drop blocks with empty annotations,
            identified by "empty" occurring in the file path.

    Returns:
        Tuple of the image paths and label paths for the requested split.
    """
    if filter_empty:
        # Empty blocks are marked by "empty" in their file name.
        image_paths = [imp for imp in image_paths if "empty" not in imp]
        label_paths = [imp for imp in label_paths if "empty" not in imp]
        assert len(image_paths) == len(label_paths)

    n_files = len(image_paths)

    train_fraction = 0.85

    n_train = int(train_fraction * n_files)
    if split == "train":
        image_paths = image_paths[:n_train]
        label_paths = label_paths[:n_train]

    elif split == "val":
        image_paths = image_paths[n_train:]
        label_paths = label_paths[n_train:]

    return image_paths, label_paths
3462
3563
36- def get_loader (split , patch_shape , batch_size , filter_empty , train_on = ["default" ]):
37- image_paths , label_paths = [], []
38-
39- if "default" in train_on :
40- all_image_paths = sorted (glob (os .path .join (DATA_ROOT , "images" , "*.tif" )))
41- all_label_paths = sorted (glob (os .path .join (DATA_ROOT , "masks" , "*.tif" )))
42- this_image_paths , this_label_paths = get_paths (all_image_paths , all_label_paths , split , filter_empty )
43- image_paths .extend (this_image_paths )
44- label_paths .extend (this_label_paths )
64+ def get_loader (root , split , patch_shape , batch_size , filter_empty ):
65+ image_paths , label_paths = get_image_and_label_paths (root )
66+ this_image_paths , this_label_paths = select_paths (image_paths , label_paths , split , filter_empty )
4567
46- if "downsampled" in train_on :
47- all_image_paths = sorted (glob (os .path .join (DATA_ROOT , "images_s2" , "*.tif" )))
48- all_label_paths = sorted (glob (os .path .join (DATA_ROOT , "masks_s2" , "*.tif" )))
49- this_image_paths , this_label_paths = get_paths (all_image_paths , all_label_paths , split , filter_empty )
50- image_paths .extend (this_image_paths )
51- label_paths .extend (this_label_paths )
68+ assert len (this_image_paths ) == len (this_label_paths )
69+ assert len (this_image_paths ) > 0
5270
5371 label_transform = torch_em .transform .label .PerObjectDistanceTransform (
5472 distances = True , boundary_distances = True , foreground = True ,
@@ -59,7 +77,7 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
5977 elif split == "val" :
6078 n_samples = 20 * batch_size
6179
62- sampler = torch_em .data .sampler .MinInstanceSampler (p_reject = 0.95 )
80+ sampler = torch_em .data .sampler .MinInstanceSampler (p_reject = 0.8 )
6381 loader = torch_em .default_segmentation_loader (
6482 raw_paths = image_paths , raw_key = None , label_paths = label_paths , label_key = None ,
6583 batch_size = batch_size , patch_shape = patch_shape , label_transform = label_transform ,
@@ -69,26 +87,45 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
6987 return loader
7088
7189
72- def main (check_loaders = False ):
73- # Parameters for training:
90+ def main ():
91+ parser = argparse .ArgumentParser ()
92+ parser .add_argument (
93+ "--root" , "-i" , help = "The root folder with the annotated training crops." ,
94+ default = ROOT_CLUSTER ,
95+ )
96+ parser .add_argument (
97+ "--batch_size" , "-b" , help = "The batch size for training. Set to 8 by default."
98+ "You may need to choose a smaller batch size to train on yoru GPU." ,
99+ default = 8 , type = int ,
100+ )
101+ parser .add_argument (
102+ "--check_loaders" , "-l" , action = "store_true" ,
103+ help = "Visualize the data loader output instead of starting a training run."
104+ )
105+ parser .add_argument (
106+ "--filter_empty" , "-f" , action = "store_true" ,
107+ help = "Whether to exclude blocks with empty annotations from the training process."
108+ )
109+ parser .add_argument (
110+ "--name" , help = "Optional name for the model to be trained. If not given the current date is used."
111+ )
112+ args = parser .parse_args ()
113+ root = args .root
114+ batch_size = args .batch_size
115+ check_loaders = args .check_loaders
116+ filter_empty = args .filter_empty
117+ run_name = datetime .now ().strftime ("%Y%m%d" ) if args .name is None else args .name
118+
119+ # Parameters for training on A100.
74120 n_iterations = 1e5
75- batch_size = 8
76- filter_empty = False
77- train_on = ["downsampled" ]
78- # train_on = ["downsampled", "default"]
79-
80- patch_shape = (32 , 128 , 128 ) if "downsampled" in train_on else (64 , 128 , 128 )
121+ patch_shape = (64 , 128 , 128 )
81122
82123 # The U-Net.
83124 model = UNet3d (in_channels = 1 , out_channels = 3 , initial_features = 32 , final_activation = "Sigmoid" )
84125
85126 # Create the training loader with train and val set.
86- train_loader = get_loader (
87- "train" , patch_shape , batch_size , filter_empty = filter_empty , train_on = train_on
88- )
89- val_loader = get_loader (
90- "val" , patch_shape , batch_size , filter_empty = filter_empty , train_on = train_on
91- )
127+ train_loader = get_loader (root , "train" , patch_shape , batch_size , filter_empty = filter_empty )
128+ val_loader = get_loader (root , "val" , patch_shape , batch_size , filter_empty = filter_empty )
92129
93130 if check_loaders :
94131 from torch_em .util .debug import check_loader
@@ -99,12 +136,7 @@ def main(check_loaders=False):
99136 loss = torch_em .loss .distance_based .DiceBasedDistanceLoss (mask_distances_in_bg = True )
100137
101138 # Create the trainer.
102- name = "cochlea_distance_unet"
103- if filter_empty :
104- name += "-filter-empty"
105- if train_on == ["downsampled" ]:
106- name += "-train-downsampled"
107-
139+ name = f"cochlea_distance_unet_{ run_name } "
108140 trainer = torch_em .default_segmentation_trainer (
109141 name = name ,
110142 model = model ,
@@ -123,4 +155,4 @@ def main(check_loaders=False):
123155
124156
125157if __name__ == "__main__" :
126- main (check_loaders = False )
158+ main ()
0 commit comments