Commit 91dc3f7

Allow flexible choice between metrics for grid-search (#812)
1 parent 86f007a

File tree

3 files changed: +43 −24 lines changed


micro_sam/evaluation/instance_segmentation.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -309,9 +309,7 @@ def run_instance_segmentation_inference(
 
 
 def evaluate_instance_segmentation_grid_search(
-    result_dir: Union[str, os.PathLike],
-    grid_search_parameters: List[str],
-    criterion: str = "mSA"
+    result_dir: Union[str, os.PathLike], grid_search_parameters: List[str], criterion: str = "mSA"
 ) -> Tuple[Dict[str, Any], float]:
     """Evaluate grid-search results.
@@ -324,7 +322,6 @@ def evaluate_instance_segmentation_grid_search(
         The best parameter setting.
         The evaluation score for the best setting.
     """
-
     # Load all the grid search results.
     gs_files = glob(os.path.join(result_dir, "*.csv"))
     gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])
```
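
With this change, the ranking criterion for the grid-search evaluation is configurable instead of being fixed to mSA. A minimal usage sketch of the updated signature (the `gs_results/` directory and the parameter list are hypothetical placeholders; ranking by `"Dice"` assumes the result CSVs were produced with the Dice metric):

```python
from micro_sam.evaluation.instance_segmentation import evaluate_instance_segmentation_grid_search

# Rank all grid-search results by the "Dice" column instead of the default "mSA".
# "gs_results/" is a placeholder directory containing the per-run CSV files.
best_kwargs, best_score = evaluate_instance_segmentation_grid_search(
    result_dir="gs_results/",
    grid_search_parameters=["iou_threshold", "projection", "box_extension"],
    criterion="Dice",
)
print("Best setting:", best_kwargs, "with score:", best_score)
```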

micro_sam/evaluation/multi_dimensional_segmentation.py

Lines changed: 39 additions & 18 deletions
````diff
@@ -4,13 +4,13 @@
 from tqdm import tqdm
 from math import floor
 from itertools import product
-from typing import Union, Tuple, Optional, List, Dict
+from typing import Union, Tuple, Optional, List, Dict, Literal
 
 import imageio.v3 as imageio
 
 import torch
 
-from elf.evaluation import mean_segmentation_accuracy
+from elf.evaluation import mean_segmentation_accuracy, dice_score
 
 from .. import util
 from ..inference import batched_inference
@@ -30,7 +30,7 @@ def default_grid_search_values_multi_dimensional_segmentation(
         iou_threshold_values: The values for `iou_threshold` used in the grid-search.
             By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used.
         projection_method_values: The values for `projection` method used in the grid-search.
-            By default the values `mask`, `bounding_box` and `points` are used.
+            By default the values `mask`, `points`, `box`, `points_and_mask` and `single_point` are used.
         box_extension_values: The values for `box_extension` used in the grid-search.
             By default values in the range from 0 to 0.25 with a stepsize of 0.025 will be used.
@@ -71,6 +71,7 @@ def segment_slices_from_ground_truth(
     verbose: bool = False,
     return_segmentation: bool = False,
     min_size: int = 0,
+    evaluation_metric: Literal["sa", "dice"] = "sa",
 ) -> Union[float, Tuple[np.ndarray, float]]:
     """Segment all objects in a volume by prompt-based segmentation in one slice per object.
@@ -94,6 +95,7 @@ def segment_slices_from_ground_truth(
         return_segmentation: Whether to return the segmented volume.
         min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.
+        evaluation_metric: The choice of supported metric to evaluate predictions.
     """
     assert volume.ndim == 3
 
@@ -116,7 +118,7 @@ def segment_slices_from_ground_truth(
     _segmentation_completed = True  # We avoid rerunning the segmentation if it is completed.
 
     skipped_label_ids = []
-    for label_id in tqdm(label_ids, desc="Segmenting per object in the volume"):
+    for label_id in tqdm(label_ids, desc="Segmenting per object in the volume", disable=not verbose):
         # Binary label volume per instance (also referred to as object)
         this_seg = (ground_truth == label_id).astype("int")
 
@@ -204,43 +206,56 @@ def segment_slices_from_ground_truth(
     else:
         curr_gt = ground_truth
 
-    msa, sa = mean_segmentation_accuracy(final_segmentation, curr_gt, return_accuracies=True)
-    results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
-    results = pd.DataFrame.from_dict([results])
+    if evaluation_metric == "sa":
+        msa, sa = mean_segmentation_accuracy(
+            segmentation=final_segmentation, groundtruth=curr_gt, return_accuracies=True
+        )
+        results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]}
+    elif evaluation_metric == "dice":
+        dice = dice_score(segmentation=final_segmentation, groundtruth=curr_gt)
+        results = {"Dice": dice}
+    else:
+        raise ValueError(f"'{evaluation_metric}' is not a supported evaluation metric. Please choose 'sa' / 'dice'.")
 
     if return_segmentation:
         return results, final_segmentation
     else:
         return results
 
 
-def _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values):
+def _get_best_parameters_from_grid_search_combinations(
+    result_dir, best_params_path, grid_search_values, evaluation_metric,
+):
     if os.path.exists(best_params_path):
         print("The best parameters are already saved at:", best_params_path)
         return
 
-    best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
+    criterion = "mSA" if evaluation_metric == "sa" else "Dice"
+    best_kwargs, best_metric = evaluate_instance_segmentation_grid_search(
+        result_dir=result_dir, grid_search_parameters=list(grid_search_values.keys()), criterion=criterion,
+    )
 
     # let's save the best parameters
-    best_kwargs["mSA"] = best_msa
+    best_kwargs[criterion] = best_metric
     best_param_df = pd.DataFrame.from_dict([best_kwargs])
     best_param_df.to_csv(best_params_path)
 
     best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
-    print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
+    print("Best grid-search result:", best_metric, "with parameters:\n", best_param_str)
 
 
 def run_multi_dimensional_segmentation_grid_search(
     volume: np.ndarray,
     ground_truth: np.ndarray,
     model_type: str,
     checkpoint_path: Union[str, os.PathLike],
-    embedding_path: Union[str, os.PathLike],
+    embedding_path: Optional[Union[str, os.PathLike]],
     result_dir: Union[str, os.PathLike],
     interactive_seg_mode: str = "box",
     verbose: bool = False,
     grid_search_values: Optional[Dict[str, List]] = None,
-    min_size: int = 0
+    min_size: int = 0,
+    evaluation_metric: Literal["sa", "dice"] = "sa",
 ):
     """Run grid search for prompt-based multi-dimensional instance segmentation.
@@ -250,7 +265,7 @@ def run_multi_dimensional_segmentation_grid_search(
     ```
     grid_search_values = {
         "iou_threshold": [0.5, 0.6, 0.7, 0.8, 0.9],
-        "projection": ["mask", "bounding_box", "points"],
+        "projection": ["mask", "box", "points"],
         "box_extension": [0, 0.1, 0.2, 0.3, 0.4, 0.5],
     }
     ```
@@ -264,12 +279,13 @@ def run_multi_dimensional_segmentation_grid_search(
         model_type: Choice of segment anything model.
         checkpoint_path: Path to the model checkpoint.
         embedding_path: Path to cache the computed embeddings.
-        result_path: Path to save the grid search results.
+        result_dir: Path to save the grid search results.
         interactive_seg_mode: Method for guiding prompt-based instance segmentation.
         verbose: Whether to get the trace for projected segmentations.
         grid_search_values: The grid search values for parameters of the `segment_slices_from_ground_truth` function.
         min_size: The minimal size for evaluating an object in the ground-truth.
            The size is measured within the central slice.
+        evaluation_metric: The choice of metric for evaluating predictions.
     """
     if grid_search_values is None:
         grid_search_values = default_grid_search_values_multi_dimensional_segmentation()
@@ -280,7 +296,9 @@ def run_multi_dimensional_segmentation_grid_search(
     result_path = os.path.join(result_dir, "all_grid_search_results.csv")
     best_params_path = os.path.join(result_dir, "grid_search_params_multi_dimensional_segmentation.csv")
     if os.path.exists(result_path):
-        _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
+        _get_best_parameters_from_grid_search_combinations(
+            result_dir, best_params_path, grid_search_values, evaluation_metric
+        )
         return best_params_path
 
     # Compute all combinations of grid search values.
@@ -292,7 +310,7 @@ def run_multi_dimensional_segmentation_grid_search(
     ]
 
     net_list = []
-    for gs_kwargs in tqdm(gs_combinations):
+    for gs_kwargs in tqdm(gs_combinations, desc="Run grid-search for multi-dimensional segmentation"):
         results = segment_slices_from_ground_truth(
             volume=volume,
             ground_truth=ground_truth,
@@ -303,6 +321,7 @@ def run_multi_dimensional_segmentation_grid_search(
             verbose=verbose,
             return_segmentation=False,
             min_size=min_size,
+            evaluation_metric=evaluation_metric,
             **gs_kwargs
         )
 
@@ -313,6 +332,8 @@ def run_multi_dimensional_segmentation_grid_search(
     res_df = pd.concat(net_list, ignore_index=True)
     res_df.to_csv(result_path)
 
-    _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values)
+    _get_best_parameters_from_grid_search_combinations(
+        result_dir, best_params_path, grid_search_values, evaluation_metric
+    )
     print("The best grid-search parameters have been computed and stored at:", best_params_path)
     return best_params_path
````
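
The new `evaluation_metric` argument is threaded from the grid-search entry point through `segment_slices_from_ground_truth` down to the metric computation, and `embedding_path` becomes optional. A sketch of running the full grid search with the Dice score (file paths, data, and the checkpoint location are hypothetical placeholders; the default grid-search values apply since `grid_search_values` is not passed):

```python
import imageio.v3 as imageio
from micro_sam.evaluation.multi_dimensional_segmentation import run_multi_dimensional_segmentation_grid_search

# Placeholder volume and ground-truth labels.
volume = imageio.imread("data/volume.tif")
ground_truth = imageio.imread("data/labels.tif")

best_params_path = run_multi_dimensional_segmentation_grid_search(
    volume=volume,
    ground_truth=ground_truth,
    model_type="vit_b",
    checkpoint_path="checkpoints/vit_b.pt",  # Placeholder checkpoint location.
    embedding_path=None,  # Now optional: embeddings are computed on the fly.
    result_dir="gs_results/",
    evaluation_metric="dice",  # Rank parameter combinations by Dice instead of mSA.
)
print("Best parameters stored at:", best_params_path)
```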

micro_sam/multi_dimensional_segmentation.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -139,8 +139,9 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
         if threshold is not None:
             iou = util.compute_iou(seg_prev, seg_z)
             if iou < threshold:
-                msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
-                print(msg)
+                if verbose:
+                    msg = f"Segmentation stopped at slice {z} due to IOU {iou} < {threshold}."
+                    print(msg)
                 break
 
         segmentation[z] = seg_z
```
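
Together with the `disable=not verbose` progress bar above, the stopping message now respects the caller's `verbose` flag, so a grid search with `verbose=False` runs silently. A sketch of a quiet single-combination evaluation (the arguments `model_type`, `iou_threshold`, `projection`, and `box_extension` are assumptions inferred from the `**gs_kwargs` call in the grid search; data paths are placeholders):

```python
import imageio.v3 as imageio
from micro_sam.evaluation.multi_dimensional_segmentation import segment_slices_from_ground_truth

# With verbose=False (the default), the per-object progress bar is disabled
# and the "Segmentation stopped at slice ..." message is suppressed.
results = segment_slices_from_ground_truth(
    volume=imageio.imread("data/volume.tif"),        # Placeholder paths.
    ground_truth=imageio.imread("data/labels.tif"),
    model_type="vit_b",
    iou_threshold=0.8,       # Passed via **gs_kwargs during the grid search.
    projection="mask",
    box_extension=0.1,
    verbose=False,
    evaluation_metric="dice",
)
print(results)  # e.g. {"Dice": ...}
```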
