@@ -311,9 +311,8 @@ def load_folder(self):
311311
312312 # TODO : better solution than loading first image always ?
313313 data_check = LoadImaged (keys = ["image" ])(images_dict [0 ])
314- # print(data)
314+
315315 check = data_check ["image" ].shape
316- # print(check)
317316 # TODO remove
318317 # z_aniso = 5 / 1.5
319318 # if zoom is not None :
@@ -384,7 +383,14 @@ def load_layer(self):
384383 self .log ("Done" )
385384 return input_image
386385
387- def model_output (self , inputs , model , post_process_transforms ):
386+ def model_output (
387+ self ,
388+ inputs ,
389+ model ,
390+ post_process_transforms ,
391+ post_process = True ,
392+ aniso_transform = None ,
393+ ):
388394
389395 inputs = inputs .to ("cpu" )
390396
@@ -412,7 +418,103 @@ def model_output(self, inputs, model, post_process_transforms):
412418 )
413419
414420 out = outputs .detach ().cpu ()
415- return out
421+
422+ if aniso_transform is not None :
423+ out = aniso_transform (out )
424+
425+ if post_process :
426+ out = np .array (out ).astype (np .float32 )
427+ out = np .squeeze (out )
428+ return out
429+ else :
430+ return out
431+
432+ def create_result_dict (
433+ self ,
434+ semantic_labels ,
435+ instance_labels ,
436+ from_layer : bool ,
437+ original = None ,
438+ data_dict = None ,
439+ i = 0 ,
440+ ):
441+
442+ if not from_layer and original is None :
443+ raise ValueError (
444+ "If the image is not from a layer, an original should always be available"
445+ )
446+
447+ if from_layer :
448+ if i != 0 :
449+ raise ValueError (
450+ "A layer's ID should always be 0 (default value)"
451+ )
452+ semantic_labels = np .swapaxes (semantic_labels , 0 , 2 )
453+
454+ return {
455+ "image_id" : i + 1 ,
456+ "original" : original ,
457+ "instance_labels" : instance_labels ,
458+ "object stats" : data_dict ,
459+ "result" : semantic_labels ,
460+ "model_name" : self .model_dict ["name" ],
461+ }
462+
463+ def get_original_filename (self , i ):
464+ return os .path .basename (self .images_filepaths [i ]).split ("." )[0 ]
465+
466+ def get_instance_result (self , semantic_labels , from_layer = False , i = - 1 ):
467+
468+ if not from_layer and i == - 1 :
469+ raise ValueError (
470+ "An ID should be provided when running from a file"
471+ )
472+
473+ if self .instance_params ["do_instance" ]:
474+ instance_labels = self .instance_seg (
475+ semantic_labels ,
476+ i + 1 ,
477+ )
478+ if from_layer :
479+ instance_labels = np .swapaxes (instance_labels , 0 , 2 )
480+ data_dict = self .stats_csv (instance_labels )
481+ else :
482+ instance_labels = None
483+ data_dict = None
484+ return instance_labels , data_dict
485+
486+ def save_image (
487+ self ,
488+ image ,
489+ from_layer = False ,
490+ i = 0 ,
491+ ):
492+
493+ if not from_layer :
494+ original_filename = "_" + self .get_original_filename (i ) + "_"
495+ else :
496+ original_filename = "_"
497+
498+ time = utils .get_date_time ()
499+
500+ file_path = (
501+ self .results_path
502+ + "/"
503+ + f"Prediction_{ i + 1 } "
504+ + original_filename
505+ + self .model_dict ["name" ]
506+ + f"_{ time } _"
507+ + self .filetype
508+ )
509+
510+ imwrite (file_path , image )
511+
512+ if from_layer :
513+ self .log (f"\n Layer prediction saved as :" )
514+ else :
515+ self .log (f"\n File n°{ i + 1 } saved as :" )
516+ filename = os .path .split (file_path )[1 ]
517+ self .log (filename )
416518
417519 def aniso_transform (self , image ):
418520 zoom = self .transforms ["zoom" ][1 ]
@@ -468,71 +570,35 @@ def method(image):
468570 return instance_labels
469571 # print(self.stats_to_csv)
470572
471- def inference_on_list (self , inf_data , i , model , post_process_transforms ):
573+ def inference_on_folder (self , inf_data , i , model , post_process_transforms ):
472574
473575 self .log ("-" * 10 )
474576 self .log (f"Inference started on image n°{ i + 1 } ..." )
475577
476578 inputs = inf_data ["image" ]
477- # print(inputs.shape)
478-
479- out = self .model_output (inputs , model , post_process_transforms )
480- out = self .aniso_transform (out )
481579
482- out = np .array (out ).astype (np .float32 )
483- out = np .squeeze (out )
484- to_instance = out # avoid post processing since thresholding is done there anyway
485- image_id = i + 1
486- time = utils .get_date_time ()
487- # print(time)
488-
489- original_filename = os .path .basename (self .images_filepaths [i ]).split (
490- "."
491- )[0 ]
492-
493- # File output save name : original-name_model_date+time_number.filetype
494- file_path = (
495- self .results_path
496- + "/"
497- + f"Prediction_{ image_id } _"
498- + original_filename
499- + "_"
500- + self .model_dict ["name" ]
501- + f"_{ time } _"
502- + self .filetype
580+ out = self .model_output (
581+ inputs ,
582+ model ,
583+ post_process_transforms ,
584+ aniso_transform = self .aniso_transform ,
503585 )
504586
505- # print(filename)
506- imwrite (file_path , out )
507-
508- self .log (f"\n File n°{ image_id } saved as :" )
509- filename = os .path .split (file_path )[1 ]
510- self .log (filename )
511-
512- #################
513- #################
514- #################
515- if self .instance_params ["do_instance" ]:
516- instance_labels = self .instance_seg (
517- to_instance , image_id , original_filename
518- )
519- data_dict = self .stats_csv (instance_labels )
520- else :
521- instance_labels = None
522- data_dict = None
587+ self .save_image (out , i = i )
588+ instance_labels , data_dict = self .get_instance_result (out , i = i )
523589
524590 original = np .array (inf_data ["image" ]).astype (np .float32 )
525591
526592 self .log (f"Inference completed on layer" )
527- result = {
528- "image_id" : i + 1 ,
529- "original" : original ,
530- "instance_labels" : instance_labels ,
531- "object stats" : data_dict ,
532- "result" : out ,
533- "model_name" : self . model_dict [ "name" ] ,
534- }
535- return result
593+
594+ return self . create_result_dict (
595+ out ,
596+ instance_labels ,
597+ from_layer = False ,
598+ original = original ,
599+ data_dict = data_dict ,
600+ i = i ,
601+ )
536602
537603 def stats_csv (self , instance_labels ):
538604 if self .stats_to_csv :
@@ -554,74 +620,20 @@ def inference_on_layer(self, image, model, post_process_transforms):
554620 self .log ("-" * 10 )
555621 self .log (f"Inference started on layer..." )
556622
557- # print(inputs.shape)
558623 image = image .type (torch .FloatTensor )
559- out = self .model_output (image , model , post_process_transforms )
560- out = self .aniso_transform (out )
561- # if self.transforms["zoom"][0]:
562- # zoom = self.transforms["zoom"][1]
563- # anisotropic_transform = Zoom(
564- # zoom=zoom,
565- # keep_size=False,
566- # padding_mode="empty",
567- # )
568- # out = anisotropic_transform(out[0])
569- ##################
570- ##################
571- ##################
572- out = post_process_transforms (out )
573- out = np .array (out ).astype (np .float32 )
574- out = np .squeeze (out )
575- to_instance = out # avoid post processing since thresholding is done there anyway
576-
577- # batch_len = out.shape[1]
578- # print("trying to check len")
579- # print(batch_len)
580- # if batch_len != 1 :
581- # sum = np.sum(out, axis=1)
582- # print(sum.shape)
583- # out = sum
584- # print(out.shape)
585624
586- time = utils .get_date_time ()
587- # print(time)
588- # File output save name : original-name_model_date+time_number.filetype
589- file_path = os .path .join (
590- self .results_path ,
591- f"Prediction_layer"
592- + "_"
593- + self .model_dict ["name" ]
594- + f"_{ time } _"
595- + self .filetype ,
625+ out = self .model_output (
626+ image ,
627+ model ,
628+ post_process_transforms ,
629+ aniso_transform = self .aniso_transform ,
596630 )
597631
598- # print(filename)
599- imwrite (file_path , out )
600-
601- self .log (f"\n Layer prediction saved as :" )
602- filename = os .path .split (file_path )[1 ]
603- self .log (filename )
632+ self .save_image (out , from_layer = True )
604633
605- #################
606- #################
607- #################
608- if self .instance_params ["do_instance" ]:
609- instance_labels = self .instance_seg (to_instance )
610- instance_labels = np .swapaxes (instance_labels , 0 , 2 )
611- data_dict = self .stats_csv (instance_labels )
612- else :
613- instance_labels = None
614- data_dict = None
634+ instance_labels , data_dict = self .get_instance_result (out ,from_layer = True )
615635
616- result = {
617- "image_id" : 0 ,
618- "original" : None ,
619- "instance_labels" : instance_labels ,
620- "object stats" : data_dict ,
621- "result" : np .swapaxes (out , 0 , 2 ),
622- "model_name" : self .model_dict ["name" ],
623- }
624- return result
636+ return self .create_result_dict (out , instance_labels , from_layer = True , data_dict = data_dict )
625637
626638 def inference (self ):
627639 """
@@ -789,7 +801,7 @@ def inference(self):
789801 ################################
790802 if is_folder :
791803 for i , inf_data in enumerate (inference_loader ):
792- yield self .inference_on_list (
804+ yield self .inference_on_folder (
793805 inf_data , i , model , post_process_transforms
794806 )
795807 elif is_layer :
@@ -1328,7 +1340,7 @@ def train(self):
13281340 f"at epoch: { best_metric_epoch } "
13291341 )
13301342 model .to ("cpu" )
1331-
1343+
13321344 except Exception as e :
13331345 self .log (f"Error : { e } " )
13341346 self .quit ()
0 commit comments