44from tqdm import tqdm
55from math import floor
66from itertools import product
7- from typing import Union , Tuple , Optional , List , Dict
7+ from typing import Union , Tuple , Optional , List , Dict , Literal
88
99import imageio .v3 as imageio
1010
1111import torch
1212
13- from elf .evaluation import mean_segmentation_accuracy
13+ from elf .evaluation import mean_segmentation_accuracy , dice_score
1414
1515from .. import util
1616from ..inference import batched_inference
@@ -30,7 +30,7 @@ def default_grid_search_values_multi_dimensional_segmentation(
3030 iou_threshold_values: The values for `iou_threshold` used in the grid-search.
3131 By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
3232 projection_method_values: The values for `projection` method used in the grid-search.
33- By default the values `mask`, `bounding_box` and `points ` are used.
33+ By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point ` are used.
3434 box_extension_values: The values for `box_extension` used in the grid-search.
3535 By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.
3636
@@ -71,6 +71,7 @@ def segment_slices_from_ground_truth(
7171 verbose : bool = False ,
7272 return_segmentation : bool = False ,
7373 min_size : int = 0 ,
74+ evaluation_metric : Literal ["sa" , "dice" ] = "sa" ,
7475) -> Union [float , Tuple [np .ndarray , float ]]:
7576 """Segment all objects in a volume by prompt-based segmentation in one slice per object.
7677
@@ -94,6 +95,7 @@ def segment_slices_from_ground_truth(
9495 return_segmentation: Whether to return the segmented volume.
9596 min_size: The minimal size for evaluating an object in the ground-truth.
9697 The size is measured within the central slice.
98+ evaluation_metric: The choice of supported metric to evaluate predictions.
9799 """
98100 assert volume .ndim == 3
99101
@@ -116,7 +118,7 @@ def segment_slices_from_ground_truth(
116118 _segmentation_completed = True # We avoid rerunning the segmentation if it is completed.
117119
118120 skipped_label_ids = []
119- for label_id in tqdm (label_ids , desc = "Segmenting per object in the volume" ):
121+ for label_id in tqdm (label_ids , desc = "Segmenting per object in the volume" , disable = not verbose ):
120122 # Binary label volume per instance (also referred to as object)
121123 this_seg = (ground_truth == label_id ).astype ("int" )
122124
@@ -204,43 +206,56 @@ def segment_slices_from_ground_truth(
204206 else :
205207 curr_gt = ground_truth
206208
207- msa , sa = mean_segmentation_accuracy (final_segmentation , curr_gt , return_accuracies = True )
208- results = {"mSA" : msa , "SA50" : sa [0 ], "SA75" : sa [5 ]}
209- results = pd .DataFrame .from_dict ([results ])
209+ if evaluation_metric == "sa" :
210+ msa , sa = mean_segmentation_accuracy (
211+ segmentation = final_segmentation , groundtruth = curr_gt , return_accuracies = True
212+ )
213+ results = {"mSA" : msa , "SA50" : sa [0 ], "SA75" : sa [5 ]}
214+ elif evaluation_metric == "dice" :
215+ dice = dice_score (segmentation = final_segmentation , groundtruth = curr_gt )
216+ results = {"Dice" : dice }
217+ else :
218+ raise ValueError (f"'{ evaluation_metric } ' is not a supported evaluation metrics. Please choose 'sa' / 'dice'." )
210219
211220 if return_segmentation :
212221 return results , final_segmentation
213222 else :
214223 return results
215224
216225
217- def _get_best_parameters_from_grid_search_combinations (result_dir , best_params_path , grid_search_values ):
226+ def _get_best_parameters_from_grid_search_combinations (
227+ result_dir , best_params_path , grid_search_values , evaluation_metric ,
228+ ):
218229 if os .path .exists (best_params_path ):
219230 print ("The best parameters are already saved at:" , best_params_path )
220231 return
221232
222- best_kwargs , best_msa = evaluate_instance_segmentation_grid_search (result_dir , list (grid_search_values .keys ()))
233+ criterion = "mSA" if evaluation_metric == "sa" else "Dice"
234+ best_kwargs , best_metric = evaluate_instance_segmentation_grid_search (
235+ result_dir = result_dir , grid_search_parameters = list (grid_search_values .keys ()), criterion = criterion ,
236+ )
223237
224238 # let's save the best parameters
225- best_kwargs ["mSA" ] = best_msa
239+ best_kwargs [criterion ] = best_metric
226240 best_param_df = pd .DataFrame .from_dict ([best_kwargs ])
227241 best_param_df .to_csv (best_params_path )
228242
229243 best_param_str = ", " .join (f"{ k } = { v } " for k , v in best_kwargs .items ())
230- print ("Best grid-search result:" , best_msa , "with parmeters:\n " , best_param_str )
244+ print ("Best grid-search result:" , best_metric , "with parmeters:\n " , best_param_str )
231245
232246
233247def run_multi_dimensional_segmentation_grid_search (
234248 volume : np .ndarray ,
235249 ground_truth : np .ndarray ,
236250 model_type : str ,
237251 checkpoint_path : Union [str , os .PathLike ],
238- embedding_path : Union [str , os .PathLike ],
252+ embedding_path : Optional [ Union [str , os .PathLike ] ],
239253 result_dir : Union [str , os .PathLike ],
240254 interactive_seg_mode : str = "box" ,
241255 verbose : bool = False ,
242256 grid_search_values : Optional [Dict [str , List ]] = None ,
243- min_size : int = 0
257+ min_size : int = 0 ,
258+ evaluation_metric : Literal ["sa" , "dice" ] = "sa" ,
244259):
245260 """Run grid search for prompt-based multi-dimensional instance segmentation.
246261
@@ -250,7 +265,7 @@ def run_multi_dimensional_segmentation_grid_search(
250265 ```
251266 grid_search_values = {
252267 "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
253- "projection": ["mask", "bounding_box ", "points"],
268+ "projection": ["mask", "box ", "points"],
254269 "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0,5],
255270 }
256271 ```
@@ -264,12 +279,13 @@ def run_multi_dimensional_segmentation_grid_search(
264279 model_type: Choice of segment anything model.
265280 checkpoint_path: Path to the model checkpoint.
266281 embedding_path: Path to cache the computed embeddings.
267- result_path : Path to save the grid search results.
282+ result_dir : Path to save the grid search results.
268283 interactive_seg_mode: Method for guiding prompt-based instance segmentation.
269284 verbose: Whether to get the trace for projected segmentations.
270285 grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function.
271286 min_size: The minimal size for evaluating an object in the ground-truth.
272287 The size is measured within the central slice.
288+ evaluation_metric: The choice of metric for evaluating predictions.
273289 """
274290 if grid_search_values is None :
275291 grid_search_values = default_grid_search_values_multi_dimensional_segmentation ()
@@ -280,7 +296,9 @@ def run_multi_dimensional_segmentation_grid_search(
280296 result_path = os .path .join (result_dir , "all_grid_search_results.csv" )
281297 best_params_path = os .path .join (result_dir , "grid_search_params_multi_dimensional_segmentation.csv" )
282298 if os .path .exists (result_path ):
283- _get_best_parameters_from_grid_search_combinations (result_dir , best_params_path , grid_search_values )
299+ _get_best_parameters_from_grid_search_combinations (
300+ result_dir , best_params_path , grid_search_values , evaluation_metric
301+ )
284302 return best_params_path
285303
286304 # Compute all combinations of grid search values.
@@ -292,7 +310,7 @@ def run_multi_dimensional_segmentation_grid_search(
292310 ]
293311
294312 net_list = []
295- for gs_kwargs in tqdm (gs_combinations ):
313+ for gs_kwargs in tqdm (gs_combinations , desc = "Run grid-search for multi-dimensional segmentation" ):
296314 results = segment_slices_from_ground_truth (
297315 volume = volume ,
298316 ground_truth = ground_truth ,
@@ -303,6 +321,7 @@ def run_multi_dimensional_segmentation_grid_search(
303321 verbose = verbose ,
304322 return_segmentation = False ,
305323 min_size = min_size ,
324+ evaluation_metric = evaluation_metric ,
306325 ** gs_kwargs
307326 )
308327
@@ -313,6 +332,8 @@ def run_multi_dimensional_segmentation_grid_search(
313332 res_df = pd .concat (net_list , ignore_index = True )
314333 res_df .to_csv (result_path )
315334
316- _get_best_parameters_from_grid_search_combinations (result_dir , best_params_path , grid_search_values )
335+ _get_best_parameters_from_grid_search_combinations (
336+ result_dir , best_params_path , grid_search_values , evaluation_metric
337+ )
317338 print ("The best grid-search parameters have been computed and stored at:" , best_params_path )
318339 return best_params_path
0 commit comments