Skip to content

Commit ccd422a

Browse files
committed
Remove deprecated call to AddChannel
1 parent c578bd9 commit ccd422a

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

napari_cellseg3d/code_models/worker_inference.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from monai.data import DataLoader, Dataset
99
from monai.inferers import sliding_window_inference
1010
from monai.transforms import (
11-
AddChannel,
11+
# AddChannel,
1212
# AsDiscrete,
1313
Compose,
1414
EnsureChannelFirstd,
@@ -231,9 +231,8 @@ def load_layer(self):
231231
f"Data array is not 3-dimensional but {volume_dims}-dimensional,"
232232
f" please check for extra channel/batch dimensions"
233233
)
234-
volume = np.swapaxes(
235-
volume, 0, 2
236-
) # for dims to be monai-like, i.e. xyz, from napari zyx
234+
volume = utils.correct_rotation(volume)
235+
# volume = np.reshape(volume, newshape=(1, 1, *volume.shape))
237236

238237
dims_check = volume.shape
239238

@@ -252,9 +251,9 @@ def load_layer(self):
252251
normalization,
253252
ToTensor(),
254253
# anisotropic_transform,
255-
AddChannel(),
254+
# AddChannel(),
256255
# SpatialPad(spatial_size=pad),
257-
AddChannel(),
256+
# AddChannel(),
258257
EnsureType(),
259258
],
260259
map_items=False,
@@ -269,9 +268,9 @@ def load_layer(self):
269268
normalization,
270269
ToTensor(),
271270
# anisotropic_transform,
272-
AddChannel(),
271+
# AddChannel(),
273272
SpatialPad(spatial_size=pad),
274-
AddChannel(),
273+
# AddChannel(),
275274
EnsureType(),
276275
],
277276
map_items=False,
@@ -541,6 +540,14 @@ def run_crf(self, image, labels, aniso_transform, image_id=0):
541540
try:
542541
if aniso_transform is not None:
543542
image = aniso_transform(image)
543+
544+
if image.shape[-3:] != labels.shape[-3:]:
545+
image = utils.correct_rotation(image)
546+
if image.shape[-3:] != labels.shape[-3:]:
547+
logger.warning(
548+
f"Labels shape mismatch: target {image.shape}, got {labels.shape}. CRF will likely fail."
549+
)
550+
544551
crf_results = crf_with_config(
545552
image, labels, config=self.config.crf_config, log=self.log
546553
)
@@ -572,13 +579,24 @@ def stats_csv(self, instance_labels):
572579
def inference_on_layer(self, image, model, post_process_transforms):
573580
self.log("-" * 10)
574581
self.log("Inference started on layer...")
575-
582+
image = image.view((1, 1, *image.shape))
583+
logger.debug(f"Layer shape @ inference input: {image.shape}")
576584
out = self.model_output(
577585
image,
578586
model,
579587
post_process_transforms,
580588
aniso_transform=self.aniso_transform,
581589
)
590+
logger.debug(f"Inference on layer result shape : {out.shape}")
591+
out = utils.correct_rotation(out)
592+
extra_dims = len(image.shape) - 3
593+
layer_shape_corrected = np.swapaxes(
594+
image, extra_dims, 2 + extra_dims
595+
).shape
596+
if out.shape[-3:] != layer_shape_corrected[-3:]:
597+
logger.debug(
598+
f"Output shape {out.shape[-3:]} does not match input shape {layer_shape_corrected[-3:]} on HWD dims even after rotation"
599+
)
582600
self.save_image(out, from_layer=True)
583601

584602
instance_labels, stats = self.get_instance_result(

0 commit comments

Comments
 (0)