99
1010from synapse_net .training import supervised_training , AZDistanceLabelTransform
1111
12- TRAIN_ROOT = "/mnt/ceph-hdd/cold_store/projects /nim00007/new_AZ_train_data"
12+ TRAIN_ROOT = "/mnt/ceph-hdd/cold /nim00007/new_AZ_train_data"
1313OUTPUT_ROOT = "./models_az_thin"
1414
1515
1616def _require_train_val_test_split (datasets ):
17- train_ratio , val_ratio , test_ratio = 0.70 , 0.1 , 0.2
17+ train_ratio , val_ratio , test_ratio = 0.60 , 0.2 , 0.2
1818
1919 def _train_val_test_split (names ):
2020 train , test = train_test_split (names , test_size = 1 - train_ratio , shuffle = True )
@@ -87,17 +87,22 @@ def train(key, ignore_label=None, use_distances=False, training_2D=False, testse
8787
8888 os .makedirs (OUTPUT_ROOT , exist_ok = True )
8989
90- datasets = ["tem" , "chemical_fixation" , "stem" , "stem_cropped" , "endbulb_of_held" , "endbulb_of_held_cropped" ]
91- train_paths = get_paths ("train" , datasets = datasets , testset = testset )
92- val_paths = get_paths ("val" , datasets = datasets , testset = testset )
90+ datasets_with_testset_true = ["tem" , "chemical_fixation" , "stem" , "endbulb_of_held" ]
91+ datasets_with_testset_false = ["stem_cropped" , "endbulb_of_held_cropped" ]
92+
93+ train_paths = get_paths ("train" , datasets = datasets_with_testset_true , testset = True )
94+ val_paths = get_paths ("val" , datasets = datasets_with_testset_true , testset = True )
95+
96+ train_paths += get_paths ("train" , datasets = datasets_with_testset_false , testset = False )
97+ val_paths += get_paths ("val" , datasets = datasets_with_testset_false , testset = False )
9398
9499 print ("Start training with:" )
95100 print (len (train_paths ), "tomograms for training" )
96101 print (len (val_paths ), "tomograms for validation" )
97102
98103 # patch_shape = [48, 256, 256]
99104 patch_shape = [48 , 384 , 384 ]
100- model_name = "v6 "
105+ model_name = "v7 "
101106
102107 # checking for 2D training
103108 if training_2D :
@@ -121,7 +126,7 @@ def train(key, ignore_label=None, use_distances=False, training_2D=False, testse
121126 sampler = torch_em .data .sampler .MinInstanceSampler (min_num_instances = 1 , p_reject = 0.85 ),
122127 n_samples_train = None , n_samples_val = 100 ,
123128 check = check ,
124- save_root = OUTPUT_ROOT ,
129+ save_root = "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ" ,
125130 n_iterations = int (2e5 ),
126131 ignore_label = ignore_label ,
127132 label_transform = label_transform ,
0 commit comments