Skip to content

Commit 30d3336

Browse files
committed
Dice update
1 parent 2eb0624 commit 30d3336

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,7 @@ def get_loader_func(num_samples):
484484
patience=self.config.scheduler_patience,
485485
verbose=VERBOSE_SCHEDULER,
486486
)
487-
dice_metric = DiceMetric(
488-
include_background=False, reduction="mean"
489-
)
487+
dice_metric = DiceMetric(include_background=True, reduction="mean")
490488

491489
best_metric = -1
492490
best_metric_epoch = -1
@@ -664,8 +662,8 @@ def get_loader_func(num_samples):
664662
)
665663
checkpoint_output = [
666664
item.numpy()
667-
for batch in checkpoint_output
668-
for item in batch
665+
for channel in checkpoint_output
666+
for item in channel
669667
]
670668
checkpoint_output[2] = checkpoint_output[2].astype(
671669
np.uint16

0 commit comments

Comments
 (0)