88from monai .data import DataLoader , Dataset
99from monai .inferers import sliding_window_inference
1010from monai .transforms import (
11- AddChannel ,
11+ # AddChannel,
1212 # AsDiscrete,
1313 Compose ,
1414 EnsureChannelFirstd ,
@@ -231,9 +231,8 @@ def load_layer(self):
231231 f"Data array is not 3-dimensional but { volume_dims } -dimensional,"
232232 f" please check for extra channel/batch dimensions"
233233 )
234- volume = np .swapaxes (
235- volume , 0 , 2
236- ) # for dims to be monai-like, i.e. xyz, from napari zyx
234+ volume = utils .correct_rotation (volume )
235+ # volume = np.reshape(volume, newshape=(1, 1, *volume.shape))
237236
238237 dims_check = volume .shape
239238
@@ -252,9 +251,9 @@ def load_layer(self):
252251 normalization ,
253252 ToTensor (),
254253 # anisotropic_transform,
255- AddChannel (),
254+ # AddChannel(),
256255 # SpatialPad(spatial_size=pad),
257- AddChannel (),
256+ # AddChannel(),
258257 EnsureType (),
259258 ],
260259 map_items = False ,
@@ -269,9 +268,9 @@ def load_layer(self):
269268 normalization ,
270269 ToTensor (),
271270 # anisotropic_transform,
272- AddChannel (),
271+ # AddChannel(),
273272 SpatialPad (spatial_size = pad ),
274- AddChannel (),
273+ # AddChannel(),
275274 EnsureType (),
276275 ],
277276 map_items = False ,
@@ -541,6 +540,14 @@ def run_crf(self, image, labels, aniso_transform, image_id=0):
541540 try :
542541 if aniso_transform is not None :
543542 image = aniso_transform (image )
543+
544+ if image .shape [- 3 :] != labels .shape [- 3 :]:
545+ image = utils .correct_rotation (image )
546+ if image .shape [- 3 :] != labels .shape [- 3 :]:
547+ logger .warning (
548+ f"Labels shape mismatch: target { image .shape } , got { labels .shape } . CRF will likely fail."
549+ )
550+
544551 crf_results = crf_with_config (
545552 image , labels , config = self .config .crf_config , log = self .log
546553 )
@@ -572,13 +579,24 @@ def stats_csv(self, instance_labels):
572579 def inference_on_layer (self , image , model , post_process_transforms ):
573580 self .log ("-" * 10 )
574581 self .log ("Inference started on layer..." )
575-
582+ image = image .view ((1 , 1 , * image .shape ))
583+ logger .debug (f"Layer shape @ inference input: { image .shape } " )
576584 out = self .model_output (
577585 image ,
578586 model ,
579587 post_process_transforms ,
580588 aniso_transform = self .aniso_transform ,
581589 )
590+ logger .debug (f"Inference on layer result shape : { out .shape } " )
591+ out = utils .correct_rotation (out )
592+ extra_dims = len (image .shape ) - 3
593+ layer_shape_corrected = np .swapaxes (
594+ image , extra_dims , 2 + extra_dims
595+ ).shape
596+ if out .shape [- 3 :] != layer_shape_corrected [- 3 :]:
597+ logger .debug (
598+ f"Output shape { out .shape [- 3 :]} does not match input shape { layer_shape_corrected [- 3 :]} on HWD dims even after rotation"
599+ )
582600 self .save_image (out , from_layer = True )
583601
584602 instance_labels , stats = self .get_instance_result (
0 commit comments