Skip to content

Commit 716fc1e

Browse files
committed
Fix training volumes
1 parent 71e4c9d commit 716fc1e

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -657,16 +657,19 @@ def get_loader_func(num_samples):
657657

658658
checkpoint_output.append(
659659
[
660-
val_outputs[0].detach().cpu().numpy(),
661-
val_inputs[0].detach().cpu().numpy(),
662-
val_labels[0]
663-
.detach()
664-
.cpu()
665-
.numpy()
666-
.astype(np.uint16),
660+
val_outputs[0].detach().cpu(),
661+
val_inputs[0].detach().cpu(),
662+
val_labels[0].detach().cpu(),
667663
]
668664
)
669-
# [np.squeeze(vol) for vol in checkpoint_output]
665+
checkpoint_output = [
666+
item.numpy()
667+
for batch in checkpoint_output
668+
for item in batch
669+
]
670+
checkpoint_output[2] = checkpoint_output[2].astype(
671+
np.uint16
672+
)
670673

671674
metric = dice_metric.aggregate().detach().item()
672675
dice_metric.reset()

0 commit comments

Comments
 (0)