88
99from synapse_net .inference .vesicles import segment_vesicles
1010from synapse_net .inference .util import parse_tiling
11+ from synapse_net .inference .inference import get_model_path
1112
1213def _require_output_folders (output_folder ):
1314 #seg_output = os.path.join(output_folder, "segmentations")
@@ -34,7 +35,7 @@ def get_volume(input_path):
3435 input_volume = f [key ][:]
3536 return input_volume
3637
37- def run_vesicle_segmentation (input_path , output_path , model_path , mask_path , mask_key ,tile_shape , halo , include_boundary , key_label ):
38+ def run_vesicle_segmentation (input_path , output_path , mask_path , mask_key ,tile_shape , halo , include_boundary , key_label , model_path = None , save_pred = False ):
3839 tiling = parse_tiling (tile_shape , halo )
3940 print (f"using tiling { tiling } " )
4041 input = get_volume (input_path )
@@ -45,7 +46,10 @@ def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mas
4546 mask = f [mask_key ][:]
4647 else :
4748 mask = None
48-
49+
50+ if model_path is None :
51+ model_path = get_model_path ("vesicles_3d" )
52+
4953 segmentation , prediction = segment_vesicles (input_volume = input , model_path = model_path , verbose = False , tiling = tiling , return_predictions = True , exclude_boundary = not include_boundary , mask = mask )
5054 foreground , boundaries = prediction [:2 ]
5155
@@ -63,14 +67,15 @@ def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mas
6367 else :
6468 f .create_dataset ("raw" , data = input , compression = "gzip" )
6569
66- key = f"vesicles/segment_from_ { key_label } "
70+ key = f"predictions/ { key_label } "
6771 if key in f :
6872 print ("Skipping" , input_path , "because" , key , "exists" )
6973 else :
7074 f .create_dataset (key , data = segmentation , compression = "gzip" )
71- f .create_dataset (f"prediction_{ key_label } /foreground" , data = foreground , compression = "gzip" )
72- f .create_dataset (f"prediction_{ key_label } /boundaries" , data = boundaries , compression = "gzip" )
73-
75+ if save_pred :
76+ f .create_dataset (f"prediction_{ key_label } /foreground" , data = foreground , compression = "gzip" )
77+ f .create_dataset (f"prediction_{ key_label } /boundaries" , data = boundaries , compression = "gzip" )
78+
7479 if mask is not None :
7580 if mask_key in f :
7681 print ("mask image already saved" )
@@ -97,7 +102,7 @@ def segment_folder(args):
97102 print (f"Mask file not found for { input_path } " )
98103 mask_path = None
99104
100- run_vesicle_segmentation (input_path , args .output_path , args . model_path , mask_path , args .mask_key , args .tile_shape , args .halo , args .include_boundary , args .key_label )
105+ run_vesicle_segmentation (input_path , args .output_path , mask_path , args .mask_key , args .tile_shape , args .halo , args .include_boundary , args .key_label , args . model_path , args . save_pred )
101106
102107def main ():
103108 parser = argparse .ArgumentParser (description = "Segment vesicles in EM tomograms." )
@@ -110,7 +115,7 @@ def main():
110115 help = "The filepath to directory where the segmentations will be saved."
111116 )
112117 parser .add_argument (
113- "--model_path" , "-m" , required = True , help = "The filepath to the vesicle model."
118+ "--model_path" , "-m" , help = "The filepath to the vesicle model."
114119 )
115120 parser .add_argument (
116121 "--mask_path" , help = "The filepath to a h5 file with a mask that will be used to restrict the segmentation. Needs to be in combination with mask_key."
@@ -131,17 +136,21 @@ def main():
131136 help = "Include vesicles that touch the top / bottom of the tomogram. By default these are excluded."
132137 )
133138 parser .add_argument (
134- "--key_label" , "-k" , default = "combined_vesicles " ,
139+ "--key_label" , "-k" , default = "vesicle_seg " ,
135140 help = "Give the key name for saving the segmentation in h5."
136141 )
142+ parser .add_argument (
143+ "--save_pred" , action = "store_true" ,
144+ help = "If set to true the prediction is also saved."
145+ )
137146 args = parser .parse_args ()
138147
139148 input_ = args .input_path
140149
141150 if os .path .isdir (input_ ):
142151 segment_folder (args )
143152 else :
144- run_vesicle_segmentation (input_ , args .output_path , args .model_path , args . mask_path , args .mask_key , args .tile_shape , args .halo , args .include_boundary , args .key_label )
153+ run_vesicle_segmentation (input_ , args .output_path , args .mask_path , args .mask_key , args .tile_shape , args .halo , args .include_boundary , args .key_label , args . model_path , args . save_pred )
145154
146155 print ("Finished segmenting!" )
147156
0 commit comments