Skip to content

Commit e283baf

Browse files
committed
Change threshold for Swin
1 parent ccc693f commit e283baf

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

napari_cellseg3d/code_models/worker_inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from monai.inferers import sliding_window_inference
1010
from monai.transforms import (
1111
AddChannel,
12-
AsDiscrete,
12+
# AsDiscrete,
1313
Compose,
1414
EnsureChannelFirstd,
1515
EnsureType,
@@ -35,6 +35,7 @@
3535
QuantileNormalization,
3636
QuantileNormalizationd,
3737
RemapTensor,
38+
Threshold,
3839
WeightsDownloader,
3940
)
4041

@@ -727,7 +728,8 @@ def inference(self):
727728
post_process_transforms = Compose(
728729
[
729730
RemapTensor(new_max=1.0, new_min=0.0),
730-
AsDiscrete(threshold=t),
731+
# AsDiscrete(threshold=t),
732+
Threshold(threshold=t),
731733
EnsureType(),
732734
]
733735
)

napari_cellseg3d/code_models/worker_training.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from monai.inferers import sliding_window_inference
1818
from monai.metrics import DiceMetric
1919
from monai.transforms import (
20-
AsDiscrete,
20+
# AsDiscrete,
2121
Compose,
2222
EnsureChannelFirstd,
2323
EnsureType,
@@ -45,6 +45,7 @@
4545
LogSignal,
4646
QuantileNormalizationd,
4747
RemapTensor,
48+
Threshold,
4849
TrainingReport,
4950
WeightsDownloader,
5051
)
@@ -638,10 +639,11 @@ def get_loader_func(num_samples):
638639
labs = decollate_batch(val_labels)
639640

640641
# TODO : more parameters/flexibility
642+
641643
post_pred = Compose(
642644
[
643645
RemapTensor(new_max=1, new_min=0),
644-
AsDiscrete(threshold=0.25), # needed ?
646+
Threshold(threshold=0.5),
645647
EnsureType(),
646648
]
647649
) #

napari_cellseg3d/code_models/workers_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,15 @@ def __call__(self, img):
213213
return utils.remap_image(img, new_max=self.max, new_min=self.min)
214214

215215

216+
class Threshold(Transform):
217+
def __init__(self, threshold=0.5):
218+
super().__init__()
219+
self.threshold = threshold
220+
221+
def __call__(self, img):
222+
return torch.where(img > self.threshold, 1, 0)
223+
224+
216225
@dataclass
217226
class InferenceResult:
218227
"""Class to record results of a segmentation job"""

0 commit comments

Comments
 (0)