Skip to content

Commit 173644a

Browse files
committed
Add grid search utility
1 parent 12916f9 commit 173644a

File tree

3 files changed

+139
-1
lines changed

3 files changed

+139
-1
lines changed

napari_cellseg3d/code_models/instance_segmentation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,6 @@ def binary_connected(
332332
return remove_small_objects(seg, thres_small)
333333

334334

335-
336335
def binary_watershed(
337336
volume,
338337
thres_objects=0.3,

napari_cellseg3d/code_plugins/plugin_convert.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Several image processing utilities."""
2+
23
from pathlib import Path
34
from warnings import warn
45

@@ -837,3 +838,137 @@ def _start(self):
837838
)
838839
else:
839840
logger.warning("Please specify a layer or a folder")
841+
842+
843+
class ThresholdGridSearchUtils(BasePluginUtils):
844+
"""Widget to run a grid search for thresholding."""
845+
846+
save_path = Path.home() / "cellseg3d" / "threshold_grid_search"
847+
848+
def __init__(self, viewer: "napari.viewer.Viewer", parent=None):
849+
"""Creates a ThresholdGridSearchUtils widget.
850+
851+
Args:
852+
viewer: viewer in which to process data
853+
parent: parent widget
854+
"""
855+
super().__init__(
856+
viewer,
857+
parent=parent,
858+
)
859+
self.do_binarize = False
860+
self.result_text = ""
861+
self.values = {}
862+
863+
self.data_panel = self._build_io_panel()
864+
# disable folder choice
865+
self.radio_buttons.setVisible(False)
866+
self.radio_buttons.setEnabled(False)
867+
868+
self.image_layer_loader.layer_list.label.setText("Prediction :")
869+
self.label_layer_loader.layer_list.label.setText("Labels :")
870+
871+
self.results_path = str(self.save_path)
872+
self.results_filewidget.text_field.setText(self.results_path)
873+
self.results_filewidget.check_ready()
874+
875+
self.start_btn = ui.Button("Start", self._start)
876+
self.result_display = ui.make_label(self.result_text, self)
877+
self.image_layer_loader.layer_list.currentIndexChanged.connect(
878+
self._reset
879+
)
880+
self.label_layer_loader.layer_list.currentIndexChanged.connect(
881+
self._reset
882+
)
883+
884+
self.container = self._build()
885+
886+
def _build(self):
887+
container = ui.ContainerWidget()
888+
889+
container.layout.addWidget(self.data_panel)
890+
ui.add_widgets(
891+
container.layout,
892+
[
893+
self.start_btn,
894+
self.result_display,
895+
],
896+
)
897+
898+
ui.ScrollArea.make_scrollable(
899+
container.layout, self, max_wh=[MAX_W, MAX_H]
900+
)
901+
self._set_io_visibility()
902+
return container
903+
904+
def _reset(self):
905+
self.values = {}
906+
self.result_text = ""
907+
self.result_display.setText("")
908+
909+
def _check_ready(self):
910+
image_data = self.image_layer_loader.layer_data()
911+
label_data = self.label_layer_loader.layer_data()
912+
if image_data is None:
913+
self.result_display.setText("Please load a prediction layer")
914+
return False
915+
if label_data is None:
916+
self.result_display.setText("Please load a labels layer")
917+
return False
918+
if label_data.shape != image_data.shape:
919+
self.result_display.setText(
920+
"Prediction and labels must have the same shape"
921+
)
922+
return False
923+
if (
924+
label_data.min() < 0
925+
or label_data.max() > 1
926+
or len(np.unique(label_data)) != 2
927+
):
928+
self.do_binarize = True
929+
return True
930+
931+
def _get_dice_graph(self):
932+
max_dice = max(self.values.values())
933+
self.result_text += "Thre | Dice | Graph\n"
934+
for tr, dice in self.values.items():
935+
bar = "°" * int((dice / max_dice) * 25)
936+
self.result_text += f"{tr:.2f} | {dice:.3f} | {bar}\n"
937+
938+
def _start(self):
939+
utils.mkdir_from_str(self.results_path)
940+
if not self._check_ready():
941+
return
942+
943+
pred_data = self.image_layer_loader.layer_data().copy()
944+
label_data = self.label_layer_loader.layer_data().copy()
945+
if self.do_binarize:
946+
logger.info("Labels values are not binary, binarizing")
947+
label_data = to_semantic(label_data)
948+
# find best threshold
949+
search_space = np.arange(0, 1, 0.05)
950+
for i in search_space:
951+
i = i.round(2)
952+
binarized = threshold(pred_data, i)
953+
binarized = np.where(binarized > 0, 1, 0)
954+
dice = utils.dice_coeff(binarized, label_data).round(3)
955+
self.values[i] = dice
956+
logger.info(f"Threshold : {i}, Dice : {dice}")
957+
958+
best_threshold = max(self.values, key=self.values.get)
959+
binarized = threshold(pred_data, best_threshold)
960+
utils.save_layer(
961+
self.results_path,
962+
f"binarized_{utils.get_date_time()}.tif",
963+
binarized,
964+
)
965+
self.layer = utils.show_result(
966+
self._viewer,
967+
self.image_layer_loader.layer(),
968+
binarized,
969+
"binarized",
970+
existing_layer=self.layer,
971+
)
972+
self.result_test = f"Best threshold : {best_threshold}, Dice : {self.values[best_threshold]}\n"
973+
self._get_dice_graph()
974+
self.result_display.setText(self.result_text)

napari_cellseg3d/code_plugins/plugin_utilities.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Central plugin for all utilities."""
2+
23
from typing import TYPE_CHECKING
34

45
if TYPE_CHECKING:
@@ -18,6 +19,7 @@
1819
FragmentUtils,
1920
RemoveSmallUtils,
2021
StatsUtils,
22+
ThresholdGridSearchUtils,
2123
ThresholdUtils,
2224
ToInstanceUtils,
2325
ToSemanticUtils,
@@ -39,6 +41,7 @@
3941
"CRF": CRFWidget,
4042
"Label statistics": StatsUtils,
4143
"Clear large labels": ArtifactRemovalUtils,
44+
"Find best threshold": ThresholdGridSearchUtils,
4245
}
4346

4447

@@ -62,6 +65,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6265
"crf",
6366
"stats",
6467
"artifacts",
68+
"find_thresh",
6569
]
6670
self._create_utils_widgets(attr_names)
6771
self.utils_choice = ui.DropdownMenu(

0 commit comments

Comments
 (0)