File tree Expand file tree Collapse file tree 3 files changed +17
-4
lines changed
napari_cellseg3d/code_models Expand file tree Collapse file tree 3 files changed +17
-4
lines changed Original file line number Diff line number Diff line change 99from monai .inferers import sliding_window_inference
1010from monai .transforms import (
1111 AddChannel ,
12- AsDiscrete ,
12+ # AsDiscrete,
1313 Compose ,
1414 EnsureChannelFirstd ,
1515 EnsureType ,
3535 QuantileNormalization ,
3636 QuantileNormalizationd ,
3737 RemapTensor ,
38+ Threshold ,
3839 WeightsDownloader ,
3940)
4041
@@ -727,7 +728,8 @@ def inference(self):
727728 post_process_transforms = Compose (
728729 [
729730 RemapTensor (new_max = 1.0 , new_min = 0.0 ),
730- AsDiscrete (threshold = t ),
731+ # AsDiscrete(threshold=t),
732+ Threshold (threshold = t ),
731733 EnsureType (),
732734 ]
733735 )
Original file line number Diff line number Diff line change 1717from monai .inferers import sliding_window_inference
1818from monai .metrics import DiceMetric
1919from monai .transforms import (
20- AsDiscrete ,
20+ # AsDiscrete,
2121 Compose ,
2222 EnsureChannelFirstd ,
2323 EnsureType ,
4545 LogSignal ,
4646 QuantileNormalizationd ,
4747 RemapTensor ,
48+ Threshold ,
4849 TrainingReport ,
4950 WeightsDownloader ,
5051)
@@ -638,10 +639,11 @@ def get_loader_func(num_samples):
638639 labs = decollate_batch (val_labels )
639640
640641 # TODO : more parameters/flexibility
642+
641643 post_pred = Compose (
642644 [
643645 RemapTensor (new_max = 1 , new_min = 0 ),
644- AsDiscrete (threshold = 0.25 ), # needed ?
646+ Threshold (threshold = 0.5 ),
645647 EnsureType (),
646648 ]
647649 ) #
Original file line number Diff line number Diff line change @@ -213,6 +213,15 @@ def __call__(self, img):
213213 return utils .remap_image (img , new_max = self .max , new_min = self .min )
214214
215215
216+ class Threshold (Transform ):
217+ def __init__ (self , threshold = 0.5 ):
218+ super ().__init__ ()
219+ self .threshold = threshold
220+
221+ def __call__ (self , img ):
222+ return torch .where (img > self .threshold , 1 , 0 )
223+
224+
216225@dataclass
217226class InferenceResult :
218227 """Class to record results of a segmentation job"""
You can’t perform that action at this time.
0 commit comments