@@ -297,17 +297,8 @@ def run():
297297 print ("Just Started" )
298298 sys .stdout .write ('Just started ' )
299299
300- class UNet (nets .UNet ):
301- def __init__ (self ,spatial_dims , in_channels , out_channels ,
302- channels ,strides ):
303- super ().__init__ (spatial_dims , in_channels , out_channels , channels , strides )
304-
305- def forward (self , ** kwargs ) -> torch .Tensor :
306- image = kwargs ["image" ].unsqueeze (0 ) # As shape of (1, 1, 128, 128, 128) bs, ch, h, w, d is expected....
307- return super ().forward (image )
308-
309300 # Loading Model 1
310- model = UNet (spatial_dims = 3 , in_channels = 1 , out_channels = 4 , channels = [16 ,32 ,64 ], strides = [2 ,2 ])
301+ model = nets . UNet (spatial_dims = 3 , in_channels = 1 , out_channels = 4 , channels = [16 ,32 ,64 ], strides = [2 ,2 ])
311302 model_file_state_dict = torch .load (join (RESOURCE_PATH , 'best.ckpt' ))['state_dict' ]
312303 pretrained_dict = {key .replace ("net." , "" ): value for key , value in model_file_state_dict .items ()}
313304 model .load_state_dict (pretrained_dict )
@@ -322,15 +313,15 @@ def forward(self, **kwargs) -> torch.Tensor:
322313
323314 pelvic_fracture_ct = load_image_file_after_transform (location = INPUT_PATH )
324315
325- logits = model .forward (** pelvic_fracture_ct )
316+ # print(pelvic_fracture_ct["image"].shape) #torch.Size([1, 128, 128, 128])
317+ logits = model .forward (pelvic_fracture_ct ["image" ].unsqueeze (0 )) # as shape needed by model is 1,1,128,128,128 bs,ch,h,w,d
326318 softmax_logits = nn .Softmax (dim = 1 )(logits )
327319 predicted_segmentation = torch .argmax (softmax_logits , 1 )
328- predicted_segmentation = predicted_segmentation .squeeze (dim = 0 )
329320
330321 print (np .unique (predicted_segmentation , return_counts = True ))
331322 # print(pelvic_fracture_ct["image"].shape) # (1, 128, 128, 128)
332- # print(predicted_segmentation.shape) # (128, 128, 128)
333- frac_LeftIliac_img , frac_sacrum_img , frac_RightIliac_img = saveDiffFrac (sitk .GetImageFromArray (pelvic_fracture_ct ["image" ][0 ]), sitk .GetImageFromArray (predicted_segmentation ))
323+ # print(predicted_segmentation.shape) # (1, 128, 128, 128)
324+ frac_LeftIliac_img , frac_sacrum_img , frac_RightIliac_img = saveDiffFrac (sitk .GetImageFromArray (pelvic_fracture_ct ["image" ][0 ]), sitk .GetImageFromArray (predicted_segmentation [ 0 ] ))
334325
335326 sys .stdout .write ('<----------Anatomical Model baseline unet completed---------------------> \n ' )
336327 # print('<----------Anatomical Model baseline unet completed--------------------->')
0 commit comments