Skip to content

Commit 3b60954

Browse files
committed
Add model output and labels to training view
1 parent c9aa0db commit 3b60954

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
PRETRAINED_WEIGHTS_DIR,
4545
LogSignal,
4646
QuantileNormalizationd,
47-
RemapTensor,
4847
TrainingReport,
4948
WeightsDownloader,
5049
)
@@ -640,14 +639,16 @@ def get_loader_func(num_samples):
640639
# TODO : more parameters/flexibility
641640
post_pred = Compose(
642641
[
643-
RemapTensor(new_max=1, new_min=0),
642+
# RemapTensor(new_max=1, new_min=0),
644643
AsDiscrete(threshold=0.5), # needed ?
645644
EnsureType(),
646645
]
647646
) #
648647
post_label = EnsureType()
649648

650-
output_raw = [RemapTensor(0, 1)(t) for t in pred]
649+
# output_raw = [RemapTensor(0, 1)(t) for t in pred]
650+
output_raw = pred
651+
651652
val_outputs = [
652653
post_pred(res_tensor) for res_tensor in pred
653654
]
@@ -658,7 +659,7 @@ def get_loader_func(num_samples):
658659

659660
# logger.debug(len(val_outputs))
660661
# logger.debug(len(val_labels))
661-
dice_test = np.array(
662+
dice_test = np.array( # TODO(cyril): remove
662663
[
663664
utils.dice_coeff(i, j)
664665
for i, j in zip(val_outputs, val_labels)
@@ -673,6 +674,7 @@ def get_loader_func(num_samples):
673674
checkpoint_output.append(
674675
[
675676
output_raw[0].detach().cpu(),
677+
val_outputs[0].detach().cpu(),
676678
val_inputs[0].detach().cpu(),
677679
val_labels[0].detach().cpu(),
678680
]
@@ -682,7 +684,7 @@ def get_loader_func(num_samples):
682684
for channel in checkpoint_output
683685
for item in channel
684686
]
685-
checkpoint_output[2] = checkpoint_output[2].astype(
687+
checkpoint_output[3] = checkpoint_output[2].astype(
686688
np.uint16
687689
)
688690

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -996,13 +996,21 @@ def _display_results(self, images, names, complete_missing=False):
996996
layer_output = self._viewer.add_image(
997997
data=images[0], name=names[0], colormap="turbo"
998998
)
999+
layer_output_discrete = self._viewer.add_image(
1000+
data=images[1], name=names[1], colormap="bop orange"
1001+
)
9991002
layer_image = self._viewer.add_image(
1000-
data=images[1], name=names[1], colormap="inferno"
1003+
data=images[2], name=names[2], colormap="inferno"
10011004
)
10021005
layer_labels = self._viewer.add_labels(
1003-
data=images[2], name=names[2]
1006+
data=images[3], name=names[3]
10041007
)
1005-
self.result_layers += [layer_output, layer_image, layer_labels]
1008+
self.result_layers += [
1009+
layer_output,
1010+
layer_output_discrete,
1011+
layer_image,
1012+
layer_labels,
1013+
]
10061014
self._viewer.grid.enabled = True
10071015
self._viewer.dims.ndisplay = 3
10081016
self._viewer.reset_view()
@@ -1018,15 +1026,22 @@ def _display_results(self, images, names, complete_missing=False):
10181026
)
10191027
self.result_layers[0] = layer_output
10201028
elif i == 1:
1029+
layer_output_discrete = self._viewer.add_image(
1030+
data=images[i],
1031+
name=names[i],
1032+
colormap="bop orange",
1033+
)
1034+
self.result_layers[1] = layer_output_discrete
1035+
elif i == 2:
10211036
layer_image = self._viewer.add_image(
10221037
data=images[i], name=names[i], colormap="inferno"
10231038
)
1024-
self.result_layers[1] = layer_image
1039+
self.result_layers[2] = layer_image
10251040
else:
10261041
layer_labels = self._viewer.add_labels(
10271042
data=images[i], name=names[i]
10281043
)
1029-
self.result_layers[2] = layer_labels
1044+
self.result_layers[3] = layer_labels
10301045
self.result_layers[i].data = images[i]
10311046
self.result_layers[i].refresh()
10321047

@@ -1042,6 +1057,7 @@ def on_yield(self, report: TrainingReport):
10421057
try:
10431058
layer_names = [
10441059
"Validation output",
1060+
"Validation output (discrete)",
10451061
"Validation image",
10461062
"Validation labels",
10471063
]

0 commit comments

Comments
 (0)