11import argparse
22from functools import partial
33
4- from .util import run_segmentation , get_model
4+ from .util import (
5+ run_segmentation , get_model , get_model_registry , get_model_training_resolution , load_custom_model
6+ )
57from ..imod .to_imod import export_helper , write_segmentation_to_imod_as_points , write_segmentation_to_imod
68from ..inference .util import inference_helper , parse_tiling
79
810
911def imod_point_cli ():
10- parser = argparse .ArgumentParser (description = "" )
12+ parser = argparse .ArgumentParser (
13+ description = "Convert a vesicle segmentation to an IMOD point model, "
14+ "corresponding to a sphere for each vesicle in the segmentation."
15+ )
1116 parser .add_argument (
1217 "--input_path" , "-i" , required = True ,
1318 help = "The filepath to the mrc file or the directory containing the tomogram data."
1419 )
1520 parser .add_argument (
1621 "--segmentation_path" , "-s" , required = True ,
17- help = "The filepath to the tif file or the directory containing the segmentations."
22+ help = "The filepath to the file or the directory containing the segmentations."
1823 )
1924 parser .add_argument (
2025 "--output_path" , "-o" , required = True ,
2126 help = "The filepath to directory where the segmentations will be saved."
2227 )
2328 parser .add_argument (
24- "--segmentation_key" , "-k" , help = ""
29+ "--segmentation_key" , "-k" ,
30+ help = "The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
31+ "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
2532 )
2633 parser .add_argument (
27- "--min_radius" , type = float , default = 10.0 , help = ""
34+ "--min_radius" , type = float , default = 10.0 ,
35+ help = "The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export." # noqa
2836 )
2937 parser .add_argument (
30- "--radius_factor" , type = float , default = 1.0 , help = "" ,
38+ "--radius_factor" , type = float , default = 1.0 ,
39+ help = "A factor for scaling the sphere radius for the export. "
40+ "This can be used to fit the size of segmented vesicles to the best matching spheres." ,
3141 )
3242 parser .add_argument (
33- "--force" , action = "store_true" , help = "" ,
43+ "--force" , action = "store_true" ,
44+ help = "Whether to over-write already present export results."
3445 )
3546 args = parser .parse_args ()
3647
@@ -51,24 +62,29 @@ def imod_point_cli():
5162
5263
5364def imod_object_cli ():
54- parser = argparse .ArgumentParser (description = "" )
65+ parser = argparse .ArgumentParser (
66+ description = "Convert segmented objects to close contour IMOD models."
67+ )
5568 parser .add_argument (
5669 "--input_path" , "-i" , required = True ,
5770 help = "The filepath to the mrc file or the directory containing the tomogram data."
5871 )
5972 parser .add_argument (
6073 "--segmentation_path" , "-s" , required = True ,
61- help = "The filepath to the tif file or the directory containing the segmentations."
74+ help = "The filepath to the file or the directory containing the segmentations."
6275 )
6376 parser .add_argument (
6477 "--output_path" , "-o" , required = True ,
6578 help = "The filepath to directory where the segmentations will be saved."
6679 )
6780 parser .add_argument (
68- "--segmentation_key" , "-k" , help = ""
81+ "--segmentation_key" , "-k" ,
82+ help = "The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
83+ "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
6984 )
7085 parser .add_argument (
71- "--force" , action = "store_true" , help = "" ,
86+ "--force" , action = "store_true" ,
87+ help = "Whether to over-write already present export results."
7288 )
7389 args = parser .parse_args ()
7490 export_helper (
@@ -82,8 +98,6 @@ def imod_object_cli():
8298
8399
84100# TODO: handle kwargs
85- # TODO: add custom model path
86- # TODO: enable autoscaling from input resolution
87101def segmentation_cli ():
88102 parser = argparse .ArgumentParser (description = "Run segmentation." )
89103 parser .add_argument (
@@ -94,9 +108,11 @@ def segmentation_cli():
94108 "--output_path" , "-o" , required = True ,
95109 help = "The filepath to directory where the segmentations will be saved."
96110 )
97- # TODO: list the availabel models here by parsing the keys of the model registry
111+ model_names = list (get_model_registry ().urls .keys ())
112+ model_names = ", " .join (model_names )
98113 parser .add_argument (
99- "--model" , "-m" , required = True , help = "The model type."
114+ "--model" , "-m" , required = True ,
115+ help = f"The model type. The following models are currently available: { model_names } "
100116 )
101117 parser .add_argument (
102118 "--mask_path" , help = "The filepath to a tif file with a mask that will be used to restrict the segmentation."
@@ -119,23 +135,45 @@ def segmentation_cli():
119135 "--data_ext" , default = ".mrc" , help = "The extension of the tomogram data. By default .mrc."
120136 )
121137 parser .add_argument (
122- "--segmentation_key " , "-s " , help = ""
138+ "--checkpoint " , "-c " , help = "Path to a custom model, e.g. from domain adaptation." ,
123139 )
124- # TODO enable autoscaling
125140 parser .add_argument (
126- "--scale" , type = float , default = None , help = ""
141+ "--segmentation_key" , "-s" ,
142+ help = "If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif." ,
143+ )
144+ parser .add_argument (
145+ "--scale" , type = float ,
146+ help = "The factor for rescaling the data before inference. "
147+ "By default, the scaling factor will be derived from the voxel size of the input data. "
148+ "If this parameter is given it will over-ride the default behavior. "
127149 )
128150 args = parser .parse_args ()
129151
130- model = get_model (args .model )
131- tiling = parse_tiling (args .tile_shape , args .halo )
132- scale = None if args .scale is None else 3 * (args .scale ,)
152+ if args .checkpoint is None :
153+ model = get_model (args .model )
154+ else :
155+ model = load_custom_model (args .checkpoint )
156+ assert model is not None , f"The model from { args .checkpoint } could not be loaded."
157+
158+ is_2d = "2d" in args .model
159+ tiling = parse_tiling (args .tile_shape , args .halo , is_2d = is_2d )
160+
161+ # If the scale argument is not passed, then we get the average training resolution for the model.
162+ # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
163+ if args .scale is None :
164+ model_resolution = get_model_training_resolution (args .model )
165+ model_resolution = tuple (model_resolution [ax ] for ax in ("yx" if is_2d else "zyx" ))
166+ scale = None
167+ # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
168+ else :
169+ model_resolution = None
170+ scale = (2 if is_2d else 3 ) * (args .scale ,)
133171
134172 segmentation_function = partial (
135- run_segmentation , model = model , model_type = args .model , verbose = False , tiling = tiling , scale = scale
173+ run_segmentation , model = model , model_type = args .model , verbose = False , tiling = tiling ,
136174 )
137175 inference_helper (
138176 args .input_path , args .output_path , segmentation_function ,
139177 mask_input_path = args .mask_path , force = args .force , data_ext = args .data_ext ,
140- output_key = args .segmentation_key ,
178+ output_key = args .segmentation_key , model_resolution = model_resolution , scale = scale ,
141179 )
0 commit comments