How to perform effective 3D segmentation of small targets in medical imaging. #7219
-
Thank you for the powerful tools MONAI provides for training models on medical images; I greatly enjoy using them in my research. I am currently working on segmenting two targets (T: primary cancer cells; N: metastatic cancer cells) in brain PET/CT images. I am facing two challenges. I am training the SwinUNETR model with 5-fold cross-validation on 98 cases (78 for training, 20 for testing), and the Dice scores are around 0.45. I would appreciate any suggestions for improvement. Thanks in advance!

The following describes the data and setup. All images have the same spacing in their metadata, and I used this original spacing in the transforms.
Imports:
import torch
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd, SpatialPadd,
    EnsureTyped, RandFlipd, RandSpatialCropd, NormalizeIntensityd,
    RandScaleIntensityd, RandShiftIntensityd, AsDiscrete,
)
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.inferers import sliding_window_inference

Training Transform:
tr_transform = Compose([
    LoadImaged(keys=['image','label'], reader='ITKReader', image_only=False),
    EnsureChannelFirstd(keys=['image','label']),
    Orientationd(keys=['image','label'], axcodes='RAS'),
    Spacingd(keys=['image','label'], pixdim=(2.734, 2.734, 2.79), mode=('bilinear', 'nearest')),
    SpatialPadd(keys=['image','label'], spatial_size=(96,96,96)),
    EnsureTyped(keys=['image','label'], device='cuda'),
    RandFlipd(keys=['image','label'], prob=0.5, spatial_axis=0),
    RandFlipd(keys=['image','label'], prob=0.5, spatial_axis=1),
    RandFlipd(keys=['image','label'], prob=0.5, spatial_axis=2),
    # random patch between 32^3 and 64^3, centred anywhere in the volume
    RandSpatialCropd(
        keys=['image','label'],
        roi_size=(32,32,32),
        max_roi_size=(64,64,64),
        random_size=True,
        random_center=True),
    SpatialPadd(keys=['image','label'], spatial_size=(64,64,64)),
    NormalizeIntensityd(keys='image', nonzero=True, channel_wise=True),
    RandScaleIntensityd(keys='image', factors=0.1, prob=1.0),
    RandShiftIntensityd(keys='image', offsets=0.1, prob=1.0),
])

Test Transform:
te_transform = Compose([
    LoadImaged(keys=['image','label'], reader='ITKReader', image_only=False),
    EnsureChannelFirstd(keys=['image','label']),
    Orientationd(keys=['image','label'], axcodes='RAS'),
    Spacingd(keys=['image','label'], pixdim=(2.734, 2.734, 2.79), mode=('bilinear', 'nearest')),
    NormalizeIntensityd(keys='image', nonzero=True, channel_wise=True),
    EnsureTyped(keys=['image','label'], device='cuda'),
])

Post Transform:
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=3)])
post_label = Compose([AsDiscrete(to_onehot=3)])

Define Loss and Metric:
dice_metric = DiceMetric(include_background=False, reduction='mean_batch')
loss = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True, include_background=True)

Define Model:
model = SwinUNETR(
    img_size=(64,64,64),
    in_channels=1,
    out_channels=3,
    feature_size=48,
    drop_rate=0.2,
    attn_drop_rate=0.2,
    dropout_path_rate=0.2,
    use_checkpoint=True)

Define training parameters:
max_epochs = 1000
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
EarlyStopping(monitor=dice_metric, mode='max', patience=100)  # schematic: stop when validation Dice has not improved for 100 epochs

Define validation methods:
# image: one validation volume of shape (B, 1, H, W, D); the model argument is named `predictor`
sliding_window_inference(image, roi_size=(64,64,64), sw_batch_size=4, predictor=model, overlap=0.5)

Results:
[Figure: visualization of validation data results]
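With two targets, a single averaged Dice of about 0.45 can hide that one class (often the smaller N target) is much worse than the other. `DiceMetric(include_background=False, reduction='mean_batch')` already returns one value per foreground class, so it is worth printing those separately. The numpy sketch below mimics the same post-processing on toy arrays (argmax, one-hot, per-class Dice, batch mean). The helpers `one_hot` and `per_class_dice` are hypothetical, and empty ground truth is handled by returning NaN and skipping it in the mean, which is assumed to approximate MONAI's default `ignore_empty=True` behaviour rather than reproduce it exactly.

```python
import numpy as np

def one_hot(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """(B, *spatial) integer labels -> (B, C, *spatial) boolean one-hot."""
    eye = np.eye(num_classes, dtype=bool)
    # eye[labels] has shape (B, *spatial, C); move the class axis to position 1
    return np.moveaxis(eye[labels], -1, 1)

def per_class_dice(logits: np.ndarray, labels: np.ndarray,
                   include_background: bool = False) -> np.ndarray:
    """Dice per class, averaged over the batch (cf. reduction='mean_batch').

    logits: (B, C, *spatial) network outputs; labels: (B, *spatial) integers.
    A sample whose ground truth is empty for a class gives NaN and is skipped.
    """
    num_classes = logits.shape[1]
    pred = one_hot(np.argmax(logits, axis=1), num_classes)  # like argmax=True, to_onehot=C
    gt = one_hot(labels, num_classes)
    if not include_background:
        pred, gt = pred[:, 1:], gt[:, 1:]
    spatial_axes = tuple(range(2, pred.ndim))
    inter = np.logical_and(pred, gt).sum(axis=spatial_axes)
    denom = pred.sum(axis=spatial_axes) + gt.sum(axis=spatial_axes)
    gt_empty = gt.sum(axis=spatial_axes) == 0
    with np.errstate(invalid="ignore", divide="ignore"):
        dice = np.where(gt_empty, np.nan, 2.0 * inter / denom)
    return np.nanmean(dice, axis=0)  # one value per (foreground) class

# Toy batch: 1 sample, 3 classes, 4x4x4 volume.
labels = np.zeros((1, 4, 4, 4), dtype=int)
labels[0, :2, :2, :2] = 1        # 8 voxels of class 1 (T)
labels[0, 3, 3, 3] = 2           # 1 voxel of class 2 (N)
logits = np.zeros((1, 3, 4, 4, 4))
logits[0, 0] = 1.0               # predict background everywhere...
logits[0, 1, :2, :2, :1] = 2.0   # ...except half of the class-1 region
result = per_class_dice(logits, labels)
print(result)  # class 1: 2*4/(4+8) = 2/3, class 2: 0.0
```

In this toy case the mean over the two classes would be about 0.33, even though class 1 alone is at 0.67; the per-class view makes that gap visible.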
Replies: 2 comments 5 replies
-
Hi @KumoLiu,
Can I ignore this? To reiterate, the values in my labels range from 0 to 2. Additionally, after using
Code:
from monai.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    Spacingd,
    RandCropByPosNegLabeld
)

def getDataLoader(filelist: list, trans: list):
    dataset = Dataset(data=filelist, transform=Compose(trans))
    return DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=0)

trans = [
    LoadImaged(keys=['image','label'], reader='ITKReader', image_only=False),
    EnsureChannelFirstd(keys=['image','label']),
    Orientationd(keys=['image','label'], axcodes='RAS'),
    MaskTrans(keys='label'),  # custom label transform; its definition is not shown in this post
    Spacingd(keys=['image','label'], pixdim=(2.734, 2.734, 2.79), mode=('bilinear', 'nearest')),
    # Note: background (negative) centres are voxels where label == 0 and
    # image > image_threshold. With image_key='label' and image_threshold=0 this
    # never holds for a non-negative label, so no negative crops can be drawn.
    RandCropByPosNegLabeld(
        keys=['image', 'label'],
        label_key='label',
        spatial_size=(64,64,64),
        pos=1,
        neg=1,
        num_samples=4,
        image_key='label',
        image_threshold=0,
        allow_smaller=True)
]

filelist = [{'image': 'image_01.nii.gz', 'label': 'label_01.nii.gz'}]
dataloader = getDataLoader(filelist, trans)
for batch in dataloader:
    # num_samples=4 crops per volume are collated into one batch of 4
    image, label = batch['image'].squeeze(), batch['label'].squeeze()
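The RandCropByPosNegLabeld documentation describes negative (background) centres as voxels where `label == 0` and `image > image_threshold`. The numpy sketch below re-implements only that selection rule on a toy channel-first volume to compare `image_key='label'` with `image_key='image'`. The helper `candidate_centers` and the toy arrays are hypothetical and are not MONAI code.

```python
import numpy as np

def candidate_centers(label, image=None, image_threshold=0.0):
    """Return (fg_indices, bg_indices) as flat voxel indices.

    Foreground: label > 0 in any channel. Background: label == 0, restricted
    to voxels where image > image_threshold when an image is given.
    """
    label_fg = np.any(label > 0, axis=0).ravel()  # collapse the channel dim
    fg = np.flatnonzero(label_fg)
    if image is None:
        bg = np.flatnonzero(~label_fg)
    else:
        valid = np.any(image > image_threshold, axis=0).ravel()
        bg = np.flatnonzero(valid & ~label_fg)
    return fg, bg

# Toy channel-first volumes: a 4x4x4 "body" of intensity 1 containing a 2x2x2 lesion.
image = np.zeros((1, 8, 8, 8))
image[0, 2:6, 2:6, 2:6] = 1.0
label = np.zeros((1, 8, 8, 8), dtype=int)
label[0, 3:5, 3:5, 3:5] = 2

fg, bg_label_key = candidate_centers(label, image=label, image_threshold=0)  # image_key='label'
_, bg_image_key = candidate_centers(label, image=image, image_threshold=0)   # image_key='image'
print(len(fg), len(bg_label_key), len(bg_image_key))  # 8 0 56
```

With the label used as the "image", the background set is empty, so only foreground-centred crops remain possible. Pointing `image_key` at the actual image keeps negatives inside the body region; the threshold then has to suit that image's intensity range.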
Hi @Minxiangliu, thanks for your interest here.
For the class-imbalance issue, I suggest using a sampling strategy, such as cropping patches according to a positive/negative ratio:
MONAI/monai/transforms/croppad/array.py, line 1049 in 782c1e6
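To see why ratio-based sampling matters for small targets, the numpy sketch below compares how often a uniformly placed crop (roughly what RandSpatialCropd with `random_center=True` does) contains any foreground, against crops whose centre is a foreground voxel with probability `pos / (pos + neg)`, which is how the RandCropByPosNegLabeld documentation defines the ratio. The function names, the toy 96³ volume with one 6³ lesion, and the simplified way crops are shifted to stay inside the volume are assumptions for illustration; MONAI's implementation differs in details.

```python
import numpy as np

rng = np.random.default_rng(0)

def crop_has_fg(label, center, size):
    """True if a cube of edge `size` around `center` contains foreground.

    The cube is shifted so it stays inside the volume (a simplification).
    """
    slices = []
    for c, dim in zip(center, label.shape):
        start = int(np.clip(int(c) - size // 2, 0, max(dim - size, 0)))
        slices.append(slice(start, start + size))
    return bool(label[tuple(slices)].any())

def sample_centers(label, n, pos=1.0, neg=1.0):
    """Draw n centres: a foreground voxel with probability pos/(pos+neg), else a background voxel."""
    fg = np.argwhere(label > 0)
    bg = np.argwhere(label == 0)
    centers = []
    for _ in range(n):
        pool = fg if rng.random() < pos / (pos + neg) else bg
        centers.append(pool[rng.integers(len(pool))])
    return centers

# Toy 96^3 label volume with a single 6^3 lesion; 32^3 training patches.
label = np.zeros((96, 96, 96), dtype=np.uint8)
label[45:51, 45:51, 45:51] = 1
size, n = 32, 2000

uniform_centers = [rng.integers(0, 96, size=3) for _ in range(n)]
ratio_centers = sample_centers(label, n, pos=1, neg=1)

frac_uniform = np.mean([crop_has_fg(label, c, size) for c in uniform_centers])
frac_ratio = np.mean([crop_has_fg(label, c, size) for c in ratio_centers])
print(f"uniform centres: {frac_uniform:.2f} of crops contain foreground")
print(f"pos=1, neg=1:    {frac_ratio:.2f} of crops contain foreground")
```

For this geometry, a uniformly placed 32³ crop touches the lesion only about 6% of the time in expectation, whereas with `pos=1, neg=1` a little over half of the crops do, because every foreground-centred crop contains the target by construction.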
You can also try adding weights to the loss function:
MONAI/monai/losses/focal_loss.py, line 86 in 782c1e6
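One common way to add weight is per-class weighting, for example by inverse class frequency, so that the rare T and N voxels contribute more per voxel than the background. The numpy sketch below computes such weights from a label volume and plugs them into a softmax focal loss, FL = -w_c (1 - p_c)^γ log p_c, averaged over voxels. The inverse-frequency formula, the smoothing term `eps=1.0`, γ = 2, and the function names are assumptions for illustration; MONAI's FocalLoss exposes its own `gamma` and `weight` arguments, whose exact semantics should be checked in the referenced source.

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes, eps=1.0):
    """w_c = N / (C * (n_c + eps)); eps (an assumption) avoids division by zero for absent classes."""
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(float)
    return labels.size / (num_classes * (counts + eps))

def weighted_focal_loss(logits, labels, class_weights, gamma=2.0):
    """Softmax focal loss averaged over voxels.

    logits: (C, *spatial); labels: (*spatial) integer class ids.
    """
    shifted = logits - logits.max(axis=0, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    # log-probability of the true class at every voxel
    log_pt = np.take_along_axis(log_probs, labels[None], axis=0)[0]
    pt = np.exp(log_pt)
    w = class_weights[labels]  # per-voxel weight looked up from the class id
    return float(np.mean(-w * (1.0 - pt) ** gamma * log_pt))

# Toy 16^3 label: a small class-1 target (8 voxels) and a tiny class-2 target (1 voxel).
labels = np.zeros((16, 16, 16), dtype=int)
labels[4:6, 4:6, 4:6] = 1
labels[10, 10, 10] = 2
weights = inverse_frequency_weights(labels, num_classes=3)
logits = np.zeros((3, 16, 16, 16))  # uninformative prediction: p = 1/3 everywhere
weighted = weighted_focal_loss(logits, labels, weights)
unweighted = weighted_focal_loss(logits, labels, np.ones(3))
print(weights)
print(weighted, unweighted)
```

Raw inverse-frequency weights can be very large for tiny classes (here several hundred for the single N voxel), so in practice they are often clipped or normalised before use.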
Hope it helps, thanks!