1+ import argparse
12import os
3+ from datetime import datetime
24from glob import glob
35
46import torch_em
5-
67from torch_em .model import UNet3d
78
8- # DATA_ROOT = "/home/pape/Work/data/moser/lightsheet"
9- DATA_ROOT = "/scratch-grete/usr/nimcpape/data/moser/lightsheet"
9+ ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet"
10+
11+
def get_image_and_label_paths(root):
    """Collect paired image and annotation paths below ``root``.

    Recursively globs all tif files under ``root``, drops files whose path
    contains an exclusion marker (annotation / cellpose-mask outputs), and
    then pairs each remaining image with its ``<name>_annotations.tif``
    neighbor in the same folder. Images without annotations are reported
    and skipped.

    Returns two equally long lists: image paths and label paths.
    """
    skip_markers = ("annotations", "cp_masks")
    candidates = sorted(glob(os.path.join(root, "**/**.tif"), recursive=True))
    candidates = [p for p in candidates if all(marker not in p for marker in skip_markers)]

    label_suffixes = ("_annotations.tif",)
    image_paths, label_paths = [], []
    for image_path in candidates:
        folder, stem = os.path.split(image_path)
        stem = os.path.splitext(stem)[0]
        # Take the first existing label candidate, or None if there is none.
        label_path = next(
            (
                cand
                for cand in (os.path.join(folder, stem + suffix) for suffix in label_suffixes)
                if os.path.exists(cand)
            ),
            None,
        )
        if label_path is None:
            print("Did not find annotations for", image_path)
            print("This image will not be used for training.")
            continue
        image_paths.append(image_path)
        label_paths.append(label_path)

    assert len(image_paths) == len(label_paths)
    return image_paths, label_paths
1040
1141
12- def get_paths (image_paths , label_paths , split , filter_empty ):
def select_paths(image_paths, label_paths, split, filter_empty):
    """Select the train or val subset of the image/label path lists.

    The first 85% of the (sorted) files form the train split, the rest the
    val split.

    Args:
        image_paths: All image paths, aligned index-by-index with ``label_paths``.
        label_paths: All label paths.
        split: Either ``"train"`` or ``"val"``.
        filter_empty: Whether to drop pairs whose path contains ``"empty"``
            (blocks without any annotated objects).

    Returns:
        The selected image paths and label paths for the requested split.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
    """
    assert len(image_paths) == len(label_paths)
    if filter_empty:
        # Filter image and label paths as pairs, so the two lists can never
        # get out of sync even if "empty" appears in only one of the names.
        pairs = [
            pair for pair in zip(image_paths, label_paths)
            if not any("empty" in path for path in pair)
        ]
        image_paths = [pair[0] for pair in pairs]
        label_paths = [pair[1] for pair in pairs]

    train_fraction = 0.85
    n_train = int(train_fraction * len(image_paths))

    if split == "train":
        return image_paths[:n_train], label_paths[:n_train]
    if split == "val":
        return image_paths[n_train:], label_paths[n_train:]
    # Previously an unknown split silently returned ALL paths, which would
    # leak training data into validation. Fail loudly instead.
    raise ValueError(f"Invalid split: '{split}', expected 'train' or 'val'.")
3462
3563
36- def get_loader (split , patch_shape , batch_size , filter_empty , train_on = ["default" ]):
37- image_paths , label_paths = [], []
38-
39- if "default" in train_on :
40- all_image_paths = sorted (glob (os .path .join (DATA_ROOT , "images" , "*.tif" )))
41- all_label_paths = sorted (glob (os .path .join (DATA_ROOT , "masks" , "*.tif" )))
42- this_image_paths , this_label_paths = get_paths (all_image_paths , all_label_paths , split , filter_empty )
43- image_paths .extend (this_image_paths )
44- label_paths .extend (this_label_paths )
64+ def get_loader (root , split , patch_shape , batch_size , filter_empty ):
65+ image_paths , label_paths = get_image_and_label_paths (root )
66+ this_image_paths , this_label_paths = select_paths (image_paths , label_paths , split , filter_empty )
4567
46- if "downsampled" in train_on :
47- all_image_paths = sorted (glob (os .path .join (DATA_ROOT , "images_s2" , "*.tif" )))
48- all_label_paths = sorted (glob (os .path .join (DATA_ROOT , "masks_s2" , "*.tif" )))
49- this_image_paths , this_label_paths = get_paths (all_image_paths , all_label_paths , split , filter_empty )
50- image_paths .extend (this_image_paths )
51- label_paths .extend (this_label_paths )
68+ assert len (this_image_paths ) == len (this_label_paths )
69+ assert len (this_image_paths ) > 0
5270
5371 label_transform = torch_em .transform .label .PerObjectDistanceTransform (
5472 distances = True , boundary_distances = True , foreground = True ,
@@ -69,26 +87,40 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
6987 return loader
7088
7189
72- def main (check_loaders = False ):
73- # Parameters for training:
90+ def main ():
91+ parser = argparse .ArgumentParser ()
92+ parser .add_argument (
93+ "--root" , "-i" , help = "The root folder with the annotated training crops." ,
94+ default = ROOT_CLUSTER ,
95+ )
96+ parser .add_argument (
97+ "--check_loaders" , "-l" , action = "store_true" ,
98+ help = "Visualize the data loader output instead of starting a training run."
99+ )
100+ parser .add_argument (
101+ "--filter_empty" , "-f" , action = "store_true" ,
102+ help = "Whether to exclude blocks with empty annotations from the training process."
103+ )
104+ parser .add_argument (
105+ "--name" , help = "Optional name for the model to be trained. If not given the current date is used."
106+ )
107+ args = parser .parse_args ()
108+ root = args .root
109+ check_loaders = args .check_loaders
110+ filter_empty = args .filter_empty
111+ run_name = datetime .now ().strftime ("%Y%m%d" ) if args .name is None else args .name
112+
113+ # Parameters for training on A100.
74114 n_iterations = 1e5
75115 batch_size = 8
76- filter_empty = False
77- train_on = ["downsampled" ]
78- # train_on = ["downsampled", "default"]
79-
80- patch_shape = (32 , 128 , 128 ) if "downsampled" in train_on else (64 , 128 , 128 )
116+ patch_shape = (64 , 128 , 128 )
81117
82118 # The U-Net.
83119 model = UNet3d (in_channels = 1 , out_channels = 3 , initial_features = 32 , final_activation = "Sigmoid" )
84120
85121 # Create the training loader with train and val set.
86- train_loader = get_loader (
87- "train" , patch_shape , batch_size , filter_empty = filter_empty , train_on = train_on
88- )
89- val_loader = get_loader (
90- "val" , patch_shape , batch_size , filter_empty = filter_empty , train_on = train_on
91- )
122+ train_loader = get_loader (root , "train" , patch_shape , batch_size , filter_empty = filter_empty )
123+ val_loader = get_loader (root , "val" , patch_shape , batch_size , filter_empty = filter_empty )
92124
93125 if check_loaders :
94126 from torch_em .util .debug import check_loader
@@ -99,12 +131,7 @@ def main(check_loaders=False):
99131 loss = torch_em .loss .distance_based .DiceBasedDistanceLoss (mask_distances_in_bg = True )
100132
101133 # Create the trainer.
102- name = "cochlea_distance_unet"
103- if filter_empty :
104- name += "-filter-empty"
105- if train_on == ["downsampled" ]:
106- name += "-train-downsampled"
107-
134+ name = f"cochlea_distance_unet_{ run_name } "
108135 trainer = torch_em .default_segmentation_trainer (
109136 name = name ,
110137 model = model ,
@@ -123,4 +150,4 @@ def main(check_loaders=False):
123150
124151
125152if __name__ == "__main__" :
126- main (check_loaders = False )
153+ main ()
0 commit comments