@@ -38,18 +38,19 @@ def get_image_and_label_paths(root):
3838            label_paths .append (label_path )
3939
4040    assert  len (image_paths ) ==  len (label_paths )
41-     return  image_paths , label_paths 
41+     return  image_paths , label_paths ,  None 
4242
4343
4444def  get_image_and_label_paths_sep_folders (root ):
4545    image_paths  =  sorted (glob (os .path .join (root , "images" , "**" , "*.tif" ), recursive = True ))
4646    label_paths  =  sorted (glob (os .path .join (root , "labels" , "**" , "*.tif" ), recursive = True ))
4747    assert  len (image_paths ) ==  len (label_paths )
4848
49-     return  image_paths , label_paths 
49+     stratify  =  [os .path .basename (os .path .dirname (f )) for  f  in  image_paths ]
50+     return  image_paths , label_paths , stratify 
5051
5152
52- def  select_paths (image_paths , label_paths , split , filter_empty , random_split = True ):
53+ def  select_paths (image_paths , label_paths , split , filter_empty , stratify ,  random_split = True ):
5354    if  filter_empty :
5455        image_paths  =  [imp  for  imp  in  image_paths  if  "empty"  not  in   imp ]
5556        label_paths  =  [imp  for  imp  in  label_paths  if  "empty"  not  in   imp ]
@@ -60,12 +61,16 @@ def select_paths(image_paths, label_paths, split, filter_empty, random_split=Tru
6061
6162    n_train  =  int (train_fraction  *  n_files )
6263    if  split  ==  "train"  and  random_split :
63-         image_paths , _ , label_paths , _  =  train_test_split (image_paths , label_paths , train_size = n_train , random_state = 42 )
64+         image_paths , _ , label_paths , _  =  train_test_split (
65+             image_paths , label_paths , train_size = n_train , random_state = 42 , stratify = stratify 
66+         )
6467    elif  split  ==  "train" :
6568        image_paths  =  image_paths [:n_train ]
6669        label_paths  =  label_paths [:n_train ]
6770    elif  split  ==  "val"  and  random_split :
68-         _ , image_paths , _ , label_paths  =  train_test_split (image_paths , label_paths , train_size = n_train , random_state = 42 )
71+         _ , image_paths , _ , label_paths  =  train_test_split (
72+             image_paths , label_paths , train_size = n_train , random_state = 42 , stratify = stratify 
73+         )
6974    elif  split  ==  "val" :
7075        image_paths  =  image_paths [n_train :]
7176        label_paths  =  label_paths [n_train :]
@@ -75,13 +80,14 @@ def select_paths(image_paths, label_paths, split, filter_empty, random_split=Tru
7580
7681def  get_loader (root , split , patch_shape , batch_size , filter_empty , separate_folders , anisotropy ):
7782    if  separate_folders :
78-         image_paths , label_paths  =  get_image_and_label_paths_sep_folders (root )
83+         image_paths , label_paths ,  stratify  =  get_image_and_label_paths_sep_folders (root )
7984    else :
80-         image_paths , label_paths  =  get_image_and_label_paths (root )
81-     this_image_paths , this_label_paths  =  select_paths (image_paths , label_paths , split , filter_empty )
85+         image_paths , label_paths ,  stratify  =  get_image_and_label_paths (root )
86+     this_image_paths , this_label_paths  =  select_paths (image_paths , label_paths , split , filter_empty ,  stratify = stratify )
8287
8388    assert  len (this_image_paths ) ==  len (this_label_paths )
8489    assert  len (this_image_paths ) >  0 
90+     print (split , ":" , len (this_image_paths ), "image crops" )
8591
8692    if  split  ==  "train" :
8793        n_samples  =  250  *  batch_size 
@@ -133,10 +139,11 @@ def main():
133139
134140    # Parameters for training on A100. 
135141    n_iterations  =  int (1e5 )
136-     patch_shape  =  (48 , 128 , 128 )
142+     patch_shape  =  (48 , 128 , 128 )  if   anisotropy   is   None   else  ( 24 ,  128 ,  128 ) 
137143
138144    # The U-Net. 
139-     model  =  get_3d_model ()
145+     scale_factors  =  None  if  args .anisotropy  is  None  else  [[1 , 2 , 2 ], [2 , 2 , 2 ], [2 , 2 , 2 ], [2 , 2 , 2 ]]
146+     model  =  get_3d_model (scale_factors = scale_factors )
140147
141148    # Create the training loader with train and val set. 
142149    train_loader , train_images , train_labels  =  get_loader (
0 commit comments