Skip to content

Commit dede29f

Browse files
committed
removed the part of dataloader and dataset while predicting image
1 parent 1179cde commit dede29f

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

inference_docker.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -297,17 +297,8 @@ def run():
297297
print("Just Started")
298298
sys.stdout.write('Just started ')
299299

300-
class UNet(nets.UNet):
301-
def __init__(self,spatial_dims, in_channels, out_channels,
302-
channels,strides):
303-
super().__init__(spatial_dims, in_channels, out_channels, channels, strides)
304-
305-
def forward(self, **kwargs) -> torch.Tensor:
306-
image = kwargs["image"].unsqueeze(0) # As shape of (1, 1, 128, 128, 128) bs, ch, h, w, d is expected....
307-
return super().forward(image)
308-
309300
# Loading Model 1
310-
model = UNet(spatial_dims=3, in_channels=1, out_channels=4, channels=[16,32,64], strides=[2,2])
301+
model = nets.UNet(spatial_dims=3, in_channels=1, out_channels=4, channels=[16,32,64], strides=[2,2])
311302
model_file_state_dict = torch.load(join(RESOURCE_PATH, 'best.ckpt'))['state_dict']
312303
pretrained_dict = {key.replace("net.", ""): value for key, value in model_file_state_dict.items()}
313304
model.load_state_dict(pretrained_dict)
@@ -322,15 +313,15 @@ def forward(self, **kwargs) -> torch.Tensor:
322313

323314
pelvic_fracture_ct = load_image_file_after_transform(location=INPUT_PATH)
324315

325-
logits = model.forward(**pelvic_fracture_ct)
316+
# print(pelvic_fracture_ct["image"].shape) #torch.Size([1, 128, 128, 128])
317+
logits = model.forward(pelvic_fracture_ct["image"].unsqueeze(0)) # as shape needed by model is 1,1,128,128,128 bs,ch,h,w,d
326318
softmax_logits = nn.Softmax(dim=1)(logits)
327319
predicted_segmentation = torch.argmax(softmax_logits, 1)
328-
predicted_segmentation = predicted_segmentation.squeeze(dim=0)
329320

330321
print(np.unique(predicted_segmentation, return_counts=True))
331322
# print(pelvic_fracture_ct["image"].shape) # (1, 128, 128, 128)
332-
# print(predicted_segmentation.shape) # (128, 128, 128)
333-
frac_LeftIliac_img, frac_sacrum_img, frac_RightIliac_img = saveDiffFrac(sitk.GetImageFromArray(pelvic_fracture_ct["image"][0]), sitk.GetImageFromArray(predicted_segmentation))
323+
# print(predicted_segmentation.shape) # (1, 128, 128, 128)
324+
frac_LeftIliac_img, frac_sacrum_img, frac_RightIliac_img = saveDiffFrac(sitk.GetImageFromArray(pelvic_fracture_ct["image"][0]), sitk.GetImageFromArray(predicted_segmentation[0]))
334325

335326
sys.stdout.write('<----------Anatomical Model baseline unet completed---------------------> \n')
336327
# print('<----------Anatomical Model baseline unet completed--------------------->')

0 commit comments

Comments
 (0)