Labels and image not overlapping / runtime error in torch.cat #6861
Unanswered
Striking-Project
asked this question in Q&A
Replies: 1 comment
-
Hi @Striking-Project, thanks for your interest here.
Hope your problem can be resolved soon. Thanks!
-
Hi @KumoLiu, these are the transformations I am using on the data I collected:
import os

import torch
from monai.transforms import Compose, EnsureTyped, LoadImaged, RandCropByPosNegLabeld

num_samples = 2
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        # move data to the GPU and drop MetaTensor tracking during training
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        # randomly crop num_samples patches of size (96, 96, 96), balanced
        # between foreground (pos) and background (neg) centres
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
            allow_smaller=True,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)
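For reference, the BTCV tutorial I adapted this from also normalizes orientation and spacing before cropping, which my pipeline currently skips. A rough sketch of those extra transforms (the axcodes, pixdim, and spatial_size values below are illustrative assumptions, not something I have verified against my data):
from monai.transforms import Orientationd, Spacingd, SpatialPadd

# sketch only: spatial-normalization steps in the style of the MONAI BTCV tutorial;
# the parameter values here are illustrative assumptions
spatial_normalization = [
    # reorient image and label to the same axis convention (e.g. RAS)
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # resample both to a common voxel spacing; nearest neighbour keeps label values discrete
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    # pad so every random crop can actually reach the full (96, 96, 96) patch
    SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96)),
]
# these would go right after LoadImaged in both train_transforms and val_transforms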
I'm trying to train SwinUNETR from the MONAI BTCV segmentation tutorial. The code for plotting the data:
import matplotlib.pyplot as plt

# slice index to display for each case
slice_map = {
    "ci1clear_PELVIS_20210104171404_11.nii.gz": 19,
    "ci2clear_PELVIS_20210104161217_6.nii.gz": 13,
    "ci3clear_PELVIS_20210106090948_6.nii.gz": 15,
    "ci4clear_PELVIS_ABIR_20210107182848_11.nii.gz": 17,
    "ci5clear_PELVIS_20210108135040_9.nii.gz": 21,
    "ci6clear_PELVIS_ABIR_20210109094218_7.nii.gz": 29,
    "ci7clear_PELVIS_ABIR_20210111122138_8.nii.gz": 21,
    "c8clear_PELVIS_ABIR_20210111132001_8.nii.gz": 120,
    "c9_PELVIS_ABIR_20210115141302_5.nii.gz": 120,
    "c10clear_PELVIS_20210118143626_6.nii.gz": 15,
}

case_num = 3
img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
img = val_ds[case_num]["image"]
label = val_ds[case_num]["label"]
img_shape = img.shape
label_shape = label.shape
print(f"image shape: {img_shape}, label shape: {label_shape}")

plt.figure("image", (18, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(img[0, :, :, slice_map[img_name]].detach().cpu(), cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[0, :, :, slice_map[img_name]].detach().cpu())
plt.show()
The label needs to be rotated 180 degrees clockwise to overlap with the image. I get the same issue when I convert the NIfTI segmentation file to PNG, but it shouldn't be that way.
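To narrow this down, a minimal check I plan to run on the raw NIfTI headers (the image filename is one of the cases above; the label filename is a placeholder):
import nibabel as nib

img_nii = nib.load("ci4clear_PELVIS_ABIR_20210107182848_11.nii.gz")
lbl_nii = nib.load("ci4clear_PELVIS_ABIR_20210107182848_11_label.nii.gz")  # placeholder label path

# if the axis codes (derived from the affines) differ, the segmentation was saved in a
# different orientation than the image, which would explain the 180-degree offset in the plots
print("image axcodes:", nib.aff2axcodes(img_nii.affine))
print("label axcodes:", nib.aff2axcodes(lbl_nii.affine))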
I also think it's causing another major problem: the dimensions not matching in torch.cat in the decoder block:
RuntimeError Traceback (most recent call last)
in <cell line: 11>()
10 metric_values = []
11 while global_step < max_iterations:
---> 12 global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
13 model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
6 frames
in train(global_step, train_loader, dice_val_best, global_step_best)
41 x, y = (batch["image"].cuda(), batch["label"].cuda())
42 with torch.cuda.amp.autocast():
---> 43 logit_map = model(x)
44 loss = loss_function(logit_map, y)
45 scaler.scale(loss).backward()
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
in forward(self, x_in)
306 enc3 = self.encoder4(hidden_states_out[2])
307 dec4 = self.encoder10(hidden_states_out[4])
--> 308 dec3 = self.decoder5(dec4, hidden_states_out[3])
309 dec2 = self.decoder4(dec3, enc3)
310 dec1 = self.decoder3(dec2, enc2)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.10/dist-packages/monai/networks/blocks/unetr_block.py in forward(self, inp, skip)
82 # number of channels for skip should equals to out_channels
83 out = self.transp_conv(inp)
---> 84
85 if out.shape[1] != skip.shape[1]:
86 skip = torch.nn.functional.interpolate(skip, size=out.shape[2:], mode='nearest')
/usr/local/lib/python3.10/dist-packages/monai/data/meta_tensor.py in __torch_function__(cls, func, types, args, kwargs)
280 if kwargs is None:
281 kwargs = {}
--> 282 ret = super().__torch_function__(func, types, args, kwargs)
283 # if `out` has been used as argument, metadata is not copied, nothing to do.
284 # if "out" in kwargs:
/usr/local/lib/python3.10/dist-packages/torch/_tensor.py in __torch_function__(cls, func, types, args, kwargs)
1293
1294 with _C.DisableTorchFunctionSubclass():
-> 1295 ret = func(*args, **kwargs)
1296 if func in get_default_nowrap_functions():
1297 return ret
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 3 for tensor number 1 in the list.
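For completeness, a quick check I intend to run on the loader output (train_loader here is the DataLoader I build from train_transforms, which is not shown above):
from monai.utils import first

batch = first(train_loader)  # first batch only, as a sanity check
print("image batch shape:", batch["image"].shape)
print("label batch shape:", batch["label"].shape)
# I expect the spatial dims to be exactly (96, 96, 96). With allow_smaller=True and no padding
# transform, crops from small volumes can come out smaller, and as far as I understand SwinUNETR
# needs each spatial dim divisible by 32, otherwise the decoder skip connections end up with
# mismatched sizes like the torch.cat error above.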