@@ -1,10 +1,12 @@
 import argparse
+import json
 import os
 from datetime import datetime
 from glob import glob
 
 import torch_em
 from flamingo_tools.training import get_supervised_loader, get_3d_model
+from sklearn.model_selection import train_test_split
 
 ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
 
@@ -54,7 +56,7 @@ def get_image_and_label_paths_sep_folders(root):
     return image_paths, label_paths
 
 
-def select_paths(image_paths, label_paths, split, filter_empty):
+def select_paths(image_paths, label_paths, split, filter_empty, random_split=True):
     if filter_empty:
         image_paths = [imp for imp in image_paths if "empty" not in imp]
         label_paths = [imp for imp in label_paths if "empty" not in imp]
@@ -64,10 +66,13 @@ def select_paths(image_paths, label_paths, split, filter_empty):
     train_fraction = 0.85
 
     n_train = int(train_fraction * n_files)
-    if split == "train":
+    if split == "train" and random_split:
+        image_paths, _, label_paths, _ = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)
+    elif split == "train":
         image_paths = image_paths[:n_train]
         label_paths = label_paths[:n_train]
-
+    elif split == "val" and random_split:
+        _, image_paths, _, label_paths = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)
     elif split == "val":
         image_paths = image_paths[n_train:]
         label_paths = label_paths[n_train:]
@@ -90,7 +95,11 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_fold
9095 elif split == "val" :
9196 n_samples = 16 * batch_size
9297
93- return get_supervised_loader (this_image_paths , this_label_paths , patch_shape , batch_size , n_samples = n_samples )
98+ return (
99+ get_supervised_loader (this_image_paths , this_label_paths , patch_shape , batch_size , n_samples = n_samples ),
100+ this_image_paths ,
101+ this_label_paths
102+ )
94103
95104
96105def main ():
@@ -131,10 +140,10 @@ def main():
     model = get_3d_model()
 
     # Create the training loader with train and val set.
-    train_loader = get_loader(
+    train_loader, train_images, train_labels = get_loader(
         root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
     )
-    val_loader = get_loader(
+    val_loader, val_images, val_labels = get_loader(
         root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
     )
 
@@ -146,8 +155,21 @@ def main():
 
     loss = torch_em.loss.distance_based.DiceBasedDistanceLoss(mask_distances_in_bg=True)
 
-    # Create the trainer.
+    # Serialize the train test split.
     name = f"cochlea_distance_unet_{run_name}"
+    ckpt_folder = os.path.join("checkpoints", name)
+    os.makedirs(ckpt_folder, exist_ok=True)
+    split_file = os.path.join(ckpt_folder, "split.json")
+    with open(split_file, "w") as f:
+        json.dump(
+            {
+                "train": {"images": train_images, "labels": train_labels},
+                "val": {"images": val_images, "labels": val_labels},
+            },
+            f, sort_keys=True, indent=2
+        )
+
+    # Create the trainer.
     trainer = torch_em.default_segmentation_trainer(
         name=name,
         model=model,
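
Note on the split logic: `train_test_split` is deterministic for fixed inputs, `train_size`, and `random_state`, so the separate calls in the `"train"` and `"val"` branches above recover complementary, non-overlapping subsets. A minimal sketch of that property (not part of the change; the file names are made up):

```python
# Sketch: two train_test_split calls with the same random_state partition the data consistently.
from sklearn.model_selection import train_test_split

paths = [f"vol_{i}.tif" for i in range(10)]    # hypothetical image files
labels = [f"lab_{i}.tif" for i in range(10)]   # hypothetical label files

n_train = int(0.85 * len(paths))
train_imgs, _, train_labs, _ = train_test_split(paths, labels, train_size=n_train, random_state=42)
_, val_imgs, _, val_labs = train_test_split(paths, labels, train_size=n_train, random_state=42)

# The two calls yield the same partition, so train and val are disjoint and cover everything.
assert set(train_imgs).isdisjoint(val_imgs)
assert sorted(train_imgs + val_imgs) == sorted(paths)
```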
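The `split.json` written next to the checkpoint makes the exact train/val selection reproducible. A hedged sketch of how it could be read back to rebuild a validation loader, assuming the `get_supervised_loader` signature used above; the run name, patch shape, and batch size are placeholder values:

```python
# Sketch: reload the serialized split and re-create the validation loader.
import json
import os

from flamingo_tools.training import get_supervised_loader

name = "cochlea_distance_unet_v1"  # hypothetical run name
with open(os.path.join("checkpoints", name, "split.json")) as f:
    split = json.load(f)

patch_shape, batch_size = (64, 128, 128), 1  # placeholder values
val_loader = get_supervised_loader(
    split["val"]["images"], split["val"]["labels"], patch_shape, batch_size,
    n_samples=16 * batch_size,
)
```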