Skip to content

Commit dd020bd

Browse files
committed
Enable window + default thresholds
1 parent 89eb46c commit dd020bd

File tree

8 files changed

+30
-5
lines changed

8 files changed

+30
-5
lines changed

napari_cellseg3d/code_models/models/TEMPLATE_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
Please note that custom model implementations are not fully supported out of the box yet, but might be in the future.
44
"""
5+
56
from abc import ABC, abstractmethod
67

78

@@ -11,6 +12,7 @@ class ModelTemplate_(ABC):
1112
weights_file = (
1213
"model_template.pth" # specify the file name of the weights file only
1314
)
15+
default_threshold = 0.5 # specify the default threshold for the model
1416

1517
@abstractmethod
1618
def __init__(

napari_cellseg3d/code_models/models/model_SegResNet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""SegResNet wrapper for napari_cellseg3d."""
2+
23
from monai.networks.nets import SegResNetVAE
34

45

56
class SegResNet_(SegResNetVAE):
67
"""SegResNet_ wrapper for napari_cellseg3d."""
78

89
weights_file = "SegResNet_latest.pth"
10+
default_threshold = 0.3
911

1012
def __init__(
1113
self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs

napari_cellseg3d/code_models/models/model_SwinUNetR.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""SwinUNetR wrapper for napari_cellseg3d."""
2+
23
from monai.networks.nets import SwinUNETR
34

45
from napari_cellseg3d.utils import LOGGER
@@ -10,6 +11,7 @@ class SwinUNETR_(SwinUNETR):
1011
"""SwinUNETR wrapper for napari_cellseg3d."""
1112

1213
weights_file = "SwinUNetR_latest.pth"
14+
default_threshold = 0.4
1315

1416
def __init__(
1517
self,

napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""TRAILMAP model, reimplemented in PyTorch."""
2+
23
from napari_cellseg3d.code_models.models.unet.model import UNet3D
34
from napari_cellseg3d.utils import LOGGER as logger
45

@@ -7,6 +8,7 @@ class TRAILMAP_MS_(UNet3D):
78
"""TRAILMAP_MS wrapper for napari_cellseg3d."""
89

910
weights_file = "TRAILMAP_MS_best_metric.pth"
11+
default_threshold = 0.15
1012

1113
# original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly TPH2 as of July 2022)
1214

napari_cellseg3d/code_models/models/model_VNet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""VNet wrapper for napari_cellseg3d."""
2+
23
from monai.networks.nets import VNet
34

45

56
class VNet_(VNet):
67
"""VNet wrapper for napari_cellseg3d."""
78

89
weights_file = "VNet_latest.pth"
10+
default_threshold = 0.15
911

1012
def __init__(self, in_channels=1, out_channels=1, **kwargs):
1113
"""Create a VNet model.

napari_cellseg3d/code_models/models/model_WNet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class WNet_(WNet_encoder):
1515
"""
1616

1717
weights_file = "wnet_latest.pth"
18+
default_threshold = 0.6
1819

1920
def __init__(
2021
self,

napari_cellseg3d/code_models/models/model_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Model for testing purposes."""
2+
23
import torch
34
from torch import nn
45

@@ -7,6 +8,7 @@ class TestModel(nn.Module):
78
"""For tests only."""
89

910
weights_file = "test.pth"
11+
default_threshold = 0.5
1012

1113
def __init__(self, **kwargs):
1214
"""Create a TestModel model."""

napari_cellseg3d/code_plugins/plugin_model_inference.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,11 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None):
143143
self.thresholding_checkbox = ui.CheckBox(
144144
"Perform thresholding", self._toggle_display_thresh
145145
)
146-
self.thresholding_checkbox.setChecked(True)
147146

148147
self.thresholding_slider = ui.Slider(
149-
default=config.PostProcessConfig().thresholding.threshold_value
150-
* 100,
148+
default=config.MODEL_LIST[
149+
self.model_choice.currentText()
150+
].default_threshold,
151151
divide_factor=100.0,
152152
parent=self,
153153
)
@@ -410,6 +410,13 @@ def _load_weights_path(self):
410410
)
411411
self._update_weights_path(file)
412412

413+
def _set_default_threshold(self):
414+
# Whenever a model is selected, set the default threshold from the model file
415+
model_name = self.model_choice.currentText()
416+
threshold = config.MODEL_LIST[model_name].default_threshold
417+
print(threshold)
418+
self.thresholding_slider.slider_value = threshold * 100
419+
413420
def _build(self):
414421
"""Puts all widgets in a layout and adds them to the napari Viewer."""
415422
# ui.add_blank(self.view_results_container, view_results_layout)
@@ -494,7 +501,8 @@ def _build(self):
494501
self.device_choice,
495502
],
496503
)
497-
self.window_infer_params.setVisible(False)
504+
self.use_window_choice.setChecked(True)
505+
# self.window_infer_params.setVisible(False)
498506

499507
inference_param_group_w.setLayout(inference_param_group_l)
500508

@@ -539,14 +547,18 @@ def _build(self):
539547
# self.instance_param_container, # instance segmentation
540548
],
541549
)
550+
# self.thresholding_slider.container.setVisible(False)
551+
self.thresholding_checkbox.setChecked(True)
542552
self._toggle_crf_choice()
543553
self.model_choice.currentIndexChanged.connect(self._toggle_crf_choice)
554+
self.model_choice.currentIndexChanged.connect(
555+
self._set_default_threshold
556+
)
544557
ModelFramework._show_io_element(
545558
self.save_stats_to_csv_box, self.use_instance_choice
546559
)
547560

548561
self.anisotropy_wdgt.container.setVisible(False)
549-
self.thresholding_slider.container.setVisible(False)
550562
self.instance_widgets.setVisible(False)
551563
self.crf_widgets.setVisible(False)
552564
self.save_stats_to_csv_box.setVisible(False)

0 commit comments

Comments
 (0)