@@ -319,27 +319,27 @@ def load_folder(self):
319319 self .log ("\n Checking dimensions..." )
320320 pad = utils .get_padding_dim (check )
321321
322- dims = self .model_dict ["model_input_size" ]
323-
324- if self .model_dict ["name" ] == "SegResNet" :
325- model = self .model_dict ["class" ].get_net (
326- input_image_size = [
327- dims ,
328- dims ,
329- dims ,
330- ]
331- )
332- elif self .model_dict ["name" ] == "SwinUNetR" :
333- model = self .model_dict ["class" ].get_net (
334- img_size = [dims , dims , dims ],
335- use_checkpoint = False ,
336- )
337- else :
338- model = self .model_dict ["class" ].get_net ()
339-
340- self .log_parameters ()
341-
342- model .to (self .device )
322+ # dims = self.model_dict["model_input_size"]
323+ #
324+ # if self.model_dict["name"] == "SegResNet":
325+ # model = self.model_dict["class"].get_net(
326+ # input_image_size=[
327+ # dims,
328+ # dims,
329+ # dims,
330+ # ]
331+ # )
332+ # elif self.model_dict["name"] == "SwinUNetR":
333+ # model = self.model_dict["class"].get_net(
334+ # img_size=[dims, dims, dims],
335+ # use_checkpoint=False,
336+ # )
337+ # else:
338+ # model = self.model_dict["class"].get_net()
339+ #
340+ # self.log_parameters()
341+ #
342+ # model.to(self.device)
343343
344344 # print("FILEPATHS PRINT")
345345 # print(self.images_filepaths)
@@ -399,18 +399,18 @@ def load_layer(self):
399399 # print(volume.shape)
400400 # print(volume.dtype)
401401 if self .use_window :
402- load_transforms = Compose (
403- [
404- ToTensor (),
405- # anisotropic_transform,
406- AddChannel (),
407- # SpatialPad(spatial_size=pad),
408- AddChannel (),
409- EnsureType (),
410- ],
411- map_items = False ,
412- log_stats = True ,
413- )
402+ load_transforms = Compose (
403+ [
404+ ToTensor (),
405+ # anisotropic_transform,
406+ AddChannel (),
407+ # SpatialPad(spatial_size=pad),
408+ AddChannel (),
409+ EnsureType (),
410+ ],
411+ map_items = False ,
412+ log_stats = True ,
413+ )
414414 else :
415415 load_transforms = Compose (
416416 [
@@ -558,13 +558,12 @@ def save_image(
558558 )
559559
560560 imwrite (file_path , image )
561+ filename = os .path .split (file_path )[1 ]
561562
562563 if from_layer :
563- self .log (f"\n Layer prediction saved as :" )
564+ self .log (f"\n Layer prediction saved as : { filename } " )
564565 else :
565- self .log (f"\n File n°{ i + 1 } saved as :" )
566- filename = os .path .split (file_path )[1 ]
567- self .log (filename )
566+ self .log (f"\n File n°{ i + 1 } saved as : { filename } " )
568567
569568 def aniso_transform (self , image ):
570569 zoom = self .transforms ["zoom" ][1 ]
@@ -680,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms):
680679
681680 self .save_image (out , from_layer = True )
682681
683- instance_labels , data_dict = self .get_instance_result (out ,from_layer = True )
682+ instance_labels , data_dict = self .get_instance_result (
683+ out , from_layer = True
684+ )
684685
685- return self .create_result_dict (out , instance_labels , from_layer = True , data_dict = data_dict )
686+ return self .create_result_dict (
687+ out , instance_labels , from_layer = True , data_dict = data_dict
688+ )
686689
687690 def inference (self ):
688691 """
@@ -724,20 +727,18 @@ def inference(self):
724727 torch .set_num_threads (1 ) # required for threading on macOS ?
725728 self .log ("Number of threads has been set to 1 for macOS" )
726729
727-
728730 try :
729731 dims = self .model_dict ["model_input_size" ]
730-
732+ self .log (f"MODEL DIMS : { dims } " )
733+ self .log (self .model_dict ["name" ])
731734
732735 if self .model_dict ["name" ] == "SegResNet" :
733- model = self .model_dict ["class" ].get_net ()(
736+ model = self .model_dict ["class" ].get_net (
734737 input_image_size = [
735738 dims ,
736739 dims ,
737740 dims ,
738741 ], # TODO FIX ! find a better way & remove model-specific code
739- out_channels = 1 ,
740- # dropout_prob=0.3,
741742 )
742743 elif self .model_dict ["name" ] == "SwinUNetR" :
743744 model = self .model_dict ["class" ].get_net (
@@ -772,10 +773,7 @@ def inference(self):
772773 AsDiscrete (threshold = t ), EnsureType ()
773774 )
774775
775-
776- self .log (
777- "\n Loading weights..."
778- )
776+ self .log ("\n Loading weights..." )
779777
780778 if self .weights_dict ["custom" ]:
781779 weights = self .weights_dict ["path" ]
@@ -1091,9 +1089,6 @@ def train(self):
10911089 model = model_class .get_net () # get an instance of the model
10921090 model = model .to (self .device )
10931091
1094-
1095-
1096-
10971092 epoch_loss_values = []
10981093 val_metric_values = []
10991094
0 commit comments