Skip to content

Commit 6605081

Browse files
committed
Reintroduced best Dice channel seeking + refacto
1 parent b4b86f8 commit 6605081

File tree

2 files changed

+34
-27
lines changed

2 files changed

+34
-27
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -779,33 +779,17 @@ def eval(self, model, epoch) -> TrainingReport:
779779
f"Val decoder outputs shape: {val_decoder_outputs.shape}"
780780
)
781781

782-
# dices = []
783-
# Find in which channel the labels are (avoid background)
784-
# for channel in range(val_outputs.shape[1]):
785-
# dices.append(
786-
# utils.dice_coeff(
787-
# y_pred=val_outputs[
788-
# 0, channel : (channel + 1), :, :, :
789-
# ],
790-
# y_true=val_labels[0],
791-
# )
792-
# )
793-
# logger.debug(f"DICE COEFF: {dices}")
794-
# max_dice_channel = torch.argmax(
795-
# torch.Tensor(dices)
796-
# )
797-
# logger.debug(
798-
# f"MAX DICE CHANNEL: {max_dice_channel}"
799-
# )
782+
max_dice_channel = utils.seek_best_dice_coeff_channel(
783+
y_pred=val_outputs, y_true=val_labels
784+
)
800785
self.dice_metric(
801-
y_pred=val_outputs,
802-
# [
803-
# :,
804-
# max_dice_channel : (max_dice_channel + 1),
805-
# :,
806-
# :,
807-
# :,
808-
# ],
786+
y_pred=val_outputs[
787+
:,
788+
max_dice_channel : (max_dice_channel + 1),
789+
:,
790+
:,
791+
:,
792+
],
809793
y=val_labels,
810794
)
811795

@@ -1282,7 +1266,7 @@ def get_patch_loader_func(num_samples):
12821266
batch_size=self.config.batch_size,
12831267
num_workers=self.config.num_workers,
12841268
)
1285-
logger.info("\nDone")
1269+
logger.debug("\nDone")
12861270

12871271
logger.debug("Optimizer")
12881272
optimizer = (

napari_cellseg3d/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,29 @@ def dice_coeff(
229229
)
230230

231231

232+
def seek_best_dice_coeff_channel(y_pred, y_true) -> torch.Tensor:
233+
"""Compute Dice-Sorensen coefficient between unsupervised model output and ground truth labels;
234+
returns the channel with the highest dice coefficient.
235+
Args:
236+
y_true: Ground truth label
237+
y_pred: Prediction label
238+
Returns: best Dice coefficient channel
239+
"""
240+
dices = []
241+
# Find in which channel the labels are (to avoid background)
242+
for channel in range(y_pred.shape[1]):
243+
dices.append(
244+
dice_coeff(
245+
y_pred=y_pred[0, channel : (channel + 1), :, :, :],
246+
y_true=y_true[0],
247+
)
248+
)
249+
LOGGER.debug(f"DICE COEFF: {dices}")
250+
max_dice_channel = torch.argmax(torch.Tensor(dices))
251+
LOGGER.debug(f"MAX DICE CHANNEL: {max_dice_channel}")
252+
return max_dice_channel
253+
254+
232255
def correct_rotation(image):
233256
"""Rotates the exes 0 and 2 in [DHW] section of image array"""
234257
extra_dims = len(image.shape) - 3

0 commit comments

Comments
 (0)