Skip to content

Commit 10e5991

Browse files
committed
Enforce tensor range [0,1]
1 parent da1f149 commit 10e5991

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

napari_cellseg3d/code_models/worker_inference.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ONNXModelWrapper,
3535
QuantileNormalization,
3636
QuantileNormalizationd,
37+
RemapTensor,
3738
WeightsDownloader,
3839
)
3940

@@ -715,11 +716,20 @@ def inference(self):
715716
# )
716717

717718
if not post_process_config.thresholding.enabled:
718-
post_process_transforms = EnsureType()
719+
post_process_transforms = Compose(
720+
[
721+
RemapTensor(new_max=1.0, new_min=0.0),
722+
EnsureType(),
723+
]
724+
)
719725
else:
720726
t = post_process_config.thresholding.threshold_value
721727
post_process_transforms = Compose(
722-
AsDiscrete(threshold=t), EnsureType()
728+
[
729+
RemapTensor(new_max=1.0, new_min=0.0),
730+
AsDiscrete(threshold=t),
731+
EnsureType(),
732+
]
723733
)
724734

725735
is_folder = self.config.images_filepaths is not None

napari_cellseg3d/code_models/worker_training.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
PRETRAINED_WEIGHTS_DIR,
4545
LogSignal,
4646
QuantileNormalizationd,
47+
RemapTensor,
4748
TrainingReport,
4849
WeightsDownloader,
4950
)
@@ -638,12 +639,15 @@ def get_loader_func(num_samples):
638639

639640
# TODO : more parameters/flexibility
640641
post_pred = Compose(
641-
AsDiscrete(threshold=0.5), # needed ?
642-
EnsureType(),
642+
[
643+
RemapTensor(new_max=1, new_min=0),
644+
AsDiscrete(threshold=0.5), # needed ?
645+
EnsureType(),
646+
]
643647
) #
644648
post_label = EnsureType()
645649

646-
output_raw = [t for t in pred]
650+
output_raw = [RemapTensor(0, 1)(t) for t in pred]
647651
val_outputs = [
648652
post_pred(res_tensor) for res_tensor in pred
649653
]

napari_cellseg3d/code_models/workers_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,16 @@ def __call__(self, img):
203203
return utils.quantile_normalization(img)
204204

205205

206+
class RemapTensor(Transform):
207+
def __init__(self, new_max, new_min):
208+
super().__init__()
209+
self.max = new_max
210+
self.min = new_min
211+
212+
def __call__(self, img):
213+
return utils.remap_image(img, new_max=self.max, new_min=self.min)
214+
215+
206216
@dataclass
207217
class InferenceResult:
208218
"""Class to record results of a segmentation job"""

0 commit comments

Comments
 (0)