2222from monai .transforms import AddChannel
2323from monai .transforms import AsDiscrete
2424from monai .transforms import Compose
25- from monai .transforms import EnsureChannelFirst
2625from monai .transforms import EnsureChannelFirstd
2726from monai .transforms import EnsureType
2827from monai .transforms import EnsureTyped
@@ -346,33 +345,32 @@ def load_folder(self):
346345 def load_layer (self ):
347346
348347 volume = np .array (self .layer .data , dtype = np .int16 )
348+ volume = np .swapaxes (
349+ volume , 0 , 2
350+ ) # for anisotropy to be monai-like, i.e. zyx
349351 print ("Loading layer" )
350- print (volume .shape )
351-
352352 dims_check = volume .shape
353353 self .log ("\n Checking dimensions..." )
354354 pad = utils .get_padding_dim (dims_check )
355- print (f"padding: { pad } " )
355+ print (volume .shape )
356+ print (volume .dtype )
356357 load_transforms = Compose (
357358 [
358- AddChannel (),
359- AddChannel (),
360359 ToTensor (),
361- EnsureType (),
362- # EnsureChannelFirst(),
363- # Orientationd(keys=["image"], axcodes="PLI"),
364360 # anisotropic_transform,
361+ AddChannel (),
365362 SpatialPad (spatial_size = pad ),
366- ]
363+ AddChannel (),
364+ EnsureType (),
365+ ],
366+ map_items = False ,
367+ log_stats = True ,
367368 )
368369
369370 self .log ("\n Loading dataset..." )
370- inference_ds = Dataset (data = volume , transform = load_transforms )
371- inference_loader = DataLoader (
372- inference_ds , batch_size = 1 , num_workers = 2
373- )
371+ input_image = load_transforms (volume )
374372 self .log ("Done" )
375- return inference_loader
373+ return input_image
376374
377375 def model_output (self , inputs , model , post_process_transforms ):
378376
@@ -457,12 +455,7 @@ def method(image):
457455 self .log (os .path .split (instance_filepath )[1 ])
458456
459457 # print(self.stats_to_csv)
460- if self .stats_to_csv :
461- data_dict = volume_stats (
462- instance_labels
463- ) # TODO test with area mesh function
464- return data_dict
465- return None
458+ return instance_labels
466459
467460 def inference_on_list (self , inf_data , i , model , post_process_transforms ):
468461
@@ -638,16 +631,13 @@ def inference_on_list(self, inf_data, i, model, post_process_transforms):
638631 # print(result)
639632 return result
640633
641- def inference_on_array (self , image , model , post_process_transforms ):
634+ def inference_on_layer (self , image , model , post_process_transforms ):
642635
643636 self .log ("-" * 10 )
644637 self .log (f"Inference started on layer..." )
645638
646639 # print(inputs.shape)
647-
648- image = ToTensor ()(image )
649- image = AddChannel ()(image )
650- image = AddChannel ()(image )
640+ image = image .type (torch .FloatTensor )
651641 inputs = image .to ("cpu" )
652642
653643 model_output = lambda inputs : post_process_transforms (
@@ -732,7 +722,13 @@ def inference_on_array(self, image, model, post_process_transforms):
732722 #################
733723 #################
734724 if self .instance_params ["do_instance" ]:
735- data_dict = self .instance_seg (to_instance )
725+ instance_labels = self .instance_seg (to_instance )
726+ if self .stats_to_csv :
727+ data_dict = volume_stats (
728+ instance_labels
729+ ) # TODO test with area mesh function
730+ else :
731+ data_dict = None
736732 # self.log(
737733 # f"\nRunning instance segmentation for image n°{image_id}"
738734 # )
@@ -791,16 +787,14 @@ def inference_on_array(self, image, model, post_process_transforms):
791787 instance_labels = None
792788 data_dict = None
793789
794- # logging(f"Inference completed on image {i+1}")
795790 result = {
796791 "image_id" : 0 ,
797792 "original" : None ,
798- "instance_labels" : instance_labels ,
793+ "instance_labels" : np . swapaxes ( instance_labels , 0 , 2 ) ,
799794 "object stats" : data_dict ,
800- "result" : out ,
795+ "result" : np . swapaxes ( out , 0 , 2 ) ,
801796 "model_name" : self .model_dict ["name" ],
802797 }
803- # print(result)
804798 return result
805799
806800 def inference (self ):
@@ -945,8 +939,6 @@ def inference(self):
945939 )
946940 elif is_folder :
947941 inference_loader = self .load_folder ()
948- elif is_layer :
949- inference_loader = self .load_layer ()
950942 ##################
951943 ##################
952944 # DEBUG
@@ -957,6 +949,10 @@ def inference(self):
957949 print (image .shape )
958950 ##################
959951 ##################
952+ elif is_layer :
953+ input_image = self .load_layer ()
954+ print (input_image .shape )
955+
960956 else :
961957 raise ValueError ("No data has been provided. Aborting." )
962958
@@ -968,13 +964,11 @@ def inference(self):
968964 if is_folder :
969965 for i , inf_data in enumerate (inference_loader ):
970966 yield self .inference_on_list (
971- inf_data ,i , model , post_process_transforms
967+ inf_data , i , model , post_process_transforms
972968 )
973969 elif is_layer :
974- image = self .layer .data
975- print (image .shape )
976- yield self .inference_on_array (
977- image , model , post_process_transforms
970+ yield self .inference_on_layer (
971+ input_image , model , post_process_transforms
978972 )
979973 # for i, inf_data in enumerate(inference_loader):
980974 #
0 commit comments