1717from monai .transforms import AddChannel
1818from monai .transforms import AsDiscrete
1919from monai .transforms import Compose
20- from monai .transforms import EnsureChannelFirst
2120from monai .transforms import EnsureChannelFirstd
2221from monai .transforms import EnsureType
2322from monai .transforms import EnsureTyped
@@ -246,33 +245,32 @@ def load_folder(self):
246245 def load_layer (self ):
247246
248247 volume = np .array (self .layer .data , dtype = np .int16 )
248+ volume = np .swapaxes (
249+ volume , 0 , 2
250+ ) # for anisotropy to be monai-like, i.e. zyx
249251 print ("Loading layer" )
250- print (volume .shape )
251-
252252 dims_check = volume .shape
253253 self .log ("\n Checking dimensions..." )
254254 pad = utils .get_padding_dim (dims_check )
255- print (f"padding: { pad } " )
255+ print (volume .shape )
256+ print (volume .dtype )
256257 load_transforms = Compose (
257258 [
258- AddChannel (),
259- AddChannel (),
260259 ToTensor (),
261- EnsureType (),
262- # EnsureChannelFirst(),
263- # Orientationd(keys=["image"], axcodes="PLI"),
264260 # anisotropic_transform,
261+ AddChannel (),
265262 SpatialPad (spatial_size = pad ),
266- ]
263+ AddChannel (),
264+ EnsureType (),
265+ ],
266+ map_items = False ,
267+ log_stats = True ,
267268 )
268269
269270 self .log ("\n Loading dataset..." )
270- inference_ds = Dataset (data = volume , transform = load_transforms )
271- inference_loader = DataLoader (
272- inference_ds , batch_size = 1 , num_workers = 2
273- )
271+ input_image = load_transforms (volume )
274272 self .log ("Done" )
275- return inference_loader
273+ return input_image
276274
277275 def model_output (self , inputs , model , post_process_transforms ):
278276
@@ -357,12 +355,7 @@ def method(image):
357355 self .log (os .path .split (instance_filepath )[1 ])
358356
359357 # print(self.stats_to_csv)
360- if self .stats_to_csv :
361- data_dict = volume_stats (
362- instance_labels
363- ) # TODO test with area mesh function
364- return data_dict
365- return None
358+ return instance_labels
366359
367360 def inference_on_list (self , inf_data , i , model , post_process_transforms ):
368361
@@ -538,16 +531,13 @@ def inference_on_list(self, inf_data, i, model, post_process_transforms):
538531 # print(result)
539532 return result
540533
541- def inference_on_array (self , image , model , post_process_transforms ):
534+ def inference_on_layer (self , image , model , post_process_transforms ):
542535
543536 self .log ("-" * 10 )
544537 self .log (f"Inference started on layer..." )
545538
546539 # print(inputs.shape)
547-
548- image = ToTensor ()(image )
549- image = AddChannel ()(image )
550- image = AddChannel ()(image )
540+ image = image .type (torch .FloatTensor )
551541 inputs = image .to ("cpu" )
552542
553543 model_output = lambda inputs : post_process_transforms (
@@ -632,7 +622,13 @@ def inference_on_array(self, image, model, post_process_transforms):
632622 #################
633623 #################
634624 if self .instance_params ["do_instance" ]:
635- data_dict = self .instance_seg (to_instance )
625+ instance_labels = self .instance_seg (to_instance )
626+ if self .stats_to_csv :
627+ data_dict = volume_stats (
628+ instance_labels
629+ ) # TODO test with area mesh function
630+ else :
631+ data_dict = None
636632 # self.log(
637633 # f"\nRunning instance segmentation for image n°{image_id}"
638634 # )
@@ -691,16 +687,14 @@ def inference_on_array(self, image, model, post_process_transforms):
691687 instance_labels = None
692688 data_dict = None
693689
694- # logging(f"Inference completed on image {i+1}")
695690 result = {
696691 "image_id" : 0 ,
697692 "original" : None ,
698- "instance_labels" : instance_labels ,
693+ "instance_labels" : np . swapaxes ( instance_labels , 0 , 2 ) ,
699694 "object stats" : data_dict ,
700- "result" : out ,
695+ "result" : np . swapaxes ( out , 0 , 2 ) ,
701696 "model_name" : self .model_dict ["name" ],
702697 }
703- # print(result)
704698 return result
705699
706700 def inference (self ):
@@ -838,8 +832,6 @@ def inference(self):
838832 )
839833 elif is_folder :
840834 inference_loader = self .load_folder ()
841- elif is_layer :
842- inference_loader = self .load_layer ()
843835 ##################
844836 ##################
845837 # DEBUG
@@ -850,6 +842,10 @@ def inference(self):
850842 print (image .shape )
851843 ##################
852844 ##################
845+ elif is_layer :
846+ input_image = self .load_layer ()
847+ print (input_image .shape )
848+
853849 else :
854850 raise ValueError ("No data has been provided. Aborting." )
855851
@@ -861,13 +857,11 @@ def inference(self):
861857 if is_folder :
862858 for i , inf_data in enumerate (inference_loader ):
863859 yield self .inference_on_list (
864- inf_data ,i , model , post_process_transforms
860+ inf_data , i , model , post_process_transforms
865861 )
866862 elif is_layer :
867- image = self .layer .data
868- print (image .shape )
869- yield self .inference_on_array (
870- image , model , post_process_transforms
863+ yield self .inference_on_layer (
864+ input_image , model , post_process_transforms
871865 )
872866 # for i, inf_data in enumerate(inference_loader):
873867 #
0 commit comments