Replies: 3 comments 6 replies
-
Hi @Gijz33, thanks for your interest here. Thanks.
Beta Was this translation helpful? Give feedback.
-
Hi @Nic-Ma, I'm sorry for my late response.

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import logging
import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
from ignite.contrib.handlers import ProgressBar
import monai
from monai.handlers import CheckpointSaver, MeanDice, StatsHandler, ValidationHandler, from_engine
from monai.transforms import (
AddChanneld,
AsDiscreted,
CastToTyped,
LoadImaged,
Orientationd,
RandAffined,
RandCropByPosNegLabeld,
RandFlipd,
RandGaussianNoised,
ScaleIntensityRanged,
Spacingd,
SpatialPadd,
EnsureTyped,
)
def get_xforms(mode="train", keys=("image", "label")):
    """Return the composed transform pipeline for one workflow stage.

    Args:
        mode: one of ``"train"``, ``"val"`` or ``"infer"``. The random
            augmentations (padding, affine, patch cropping, noise, flips)
            are only added for ``"train"``.
        keys: dictionary keys the transforms operate on. ``keys[0]`` is the
            item that gets intensity scaling and noise; in ``"train"`` mode
            ``keys[1]`` is used as the label for patch sampling.

    Returns:
        A ``monai.transforms.Compose`` wrapping the selected transforms.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    # Target dtypes handed to CastToTyped, one entry per key expected in that mode.
    cast_dtypes = {
        "train": (np.float32, np.uint8),
        "val": (np.float32, np.uint8),
        "infer": (np.float32,),
    }
    # Reject typos early instead of failing later with an unrelated error.
    if mode not in cast_dtypes:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(cast_dtypes)}.")

    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        # slice the interpolation modes so an image-only ``keys`` gets a single entry
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]
    if mode == "train":
        xforms.extend(
            [
                SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),  # ensure at least 192x192
                RandAffined(
                    keys,
                    prob=0.15,
                    rotate_range=(0.05, 0.05, None),  # 3 parameters control the transform on 3 dimensions
                    scale_range=(0.1, 0.1, None),
                    mode=("bilinear", "nearest"),
                    as_tensor_output=False,
                ),
                RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=3),
                RandGaussianNoised(keys[0], prob=0.15, std=0.01),
                RandFlipd(keys, spatial_axis=0, prob=0.5),
                RandFlipd(keys, spatial_axis=1, prob=0.5),
                RandFlipd(keys, spatial_axis=2, prob=0.5),
            ]
        )
    xforms.extend([CastToTyped(keys, dtype=cast_dtypes[mode]), EnsureTyped(keys)])
    return monai.transforms.Compose(xforms)
def get_net():
    """Build the 3D UNet segmentation network.

    Returns:
        A ``monai.networks.nets.UNet`` with one input channel and two output
        channels, five encoder levels and dropout of 0.5.
    """
    num_classes = 2
    unet_kwargs = dict(
        # NOTE(review): newer MONAI releases appear to name this argument
        # ``spatial_dims`` -- confirm against the installed version.
        dimensions=3,
        in_channels=1,
        out_channels=num_classes,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        # num_res_units=2,
        dropout=0.5,
    )
    return monai.networks.nets.UNet(**unet_kwargs)
def get_inferer(_mode=None):
    """Create the sliding-window inferer.

    Args:
        _mode: not used by the body.

    Returns:
        A ``monai.inferers.SlidingWindowInferer`` using 192x192x16 windows,
        two windows per forward pass, 50% overlap, gaussian blending and
        replicate padding.
    """
    return monai.inferers.SlidingWindowInferer(
        roi_size=(192, 192, 16),
        sw_batch_size=2,
        overlap=0.5,
        mode="gaussian",
        padding_mode="replicate",
    )
class DiceCELoss(nn.Module):
    """Loss that adds a Dice term and a cross-entropy term."""

    def __init__(self):
        super().__init__()
        self.dice = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        """Return ``dice(y_pred, y_true) + cross_entropy(y_pred, y_true)``.

        Args:
            y_pred: network output passed unchanged to both loss terms.
            y_true: target with a singleton channel axis at dim 1, i.e.
                shape (B, 1, D, H, W).
        """
        # CrossEntropyLoss wants class indices of shape (B, D, H, W),
        # so drop the channel axis and cast to integer for that term only.
        class_index_target = y_true.squeeze(dim=1).long()
        dice_term = self.dice(y_pred, y_true)
        ce_term = self.cross_entropy(y_pred, class_index_target)
        return dice_term + ce_term
def _split_train_val(images, labels, keys, train_frac, val_frac):
    """Pair image/label paths and split them into disjoint train and validation lists."""
    n_total = len(images)
    if n_total != len(labels):
        raise ValueError(
            f"found {n_total} images but {len(labels)} labels; every image needs a matching label."
        )
    if n_total < 2:
        raise ValueError(
            f"need at least 2 image/label pairs to build train and validation sets, found {n_total}."
        )
    # int(train_frac * n) + 1 items go to training and the remainder (capped by
    # val_frac) to validation. Both counts are clamped so at least one item is
    # always held out: with a validation count of 0, ``seq[-0:]`` would select
    # the WHOLE list and validation would run on the training data.
    n_train = min(int(train_frac * n_total) + 1, n_total - 1)
    n_val = min(n_total - n_train, max(1, int(val_frac * n_total)))
    pairs = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images, labels)]
    return pairs[:n_train], pairs[n_total - n_val:]


def train(data_folder=".", model_folder="runs"):
    """Run the training pipeline.

    Collects ``*_ct.nii.gz`` / ``*_seg.nii.gz`` files from ``data_folder``,
    splits them into train and validation sets, and trains the UNet with a
    Dice + cross-entropy loss. The evaluator runs after every epoch and saves
    the best checkpoints (by ``val_mean_dice``) into ``model_folder``.

    Args:
        data_folder: folder containing the image and label volumes.
        model_folder: folder the checkpoint saver writes to.

    Raises:
        ValueError: if the image and label counts differ, or if fewer than two
            image/label pairs are found.
    """
    images = sorted(glob.glob(os.path.join(data_folder, "*_ct.nii.gz")))
    labels = sorted(glob.glob(os.path.join(data_folder, "*_seg.nii.gz")))
    logging.info(f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    train_frac, val_frac = 0.8, 0.2
    train_files, val_files = _split_train_val(images, labels, keys, train_frac, val_frac)
    logging.info(f"training: train {len(train_files)} val {len(val_files)}, folder: {data_folder}")

    # create a training data loader
    batch_size = 2
    logging.info(f"batch size {batch_size}")
    train_transforms = get_xforms("train", keys)
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create the UNet and the Adam optimizer (the loss is handed to the trainer below)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)
    # NOTE: ``momentum`` is only logged; the Adam optimizer below is built without it.
    max_epochs, lr, momentum = 500, 1e-4, 0.95
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose(
        [EnsureTyped(keys=("pred", "label")), AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=2)]
    )
    val_handlers = [
        ProgressBar(),
        CheckpointSaver(save_dir=model_folder, save_dict={"net": net}, save_key_metric=True, key_metric_n_saved=3),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        postprocessing=val_post_transform,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=False, output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
    """Run inference and write the predicted segmentations to disk.

    Loads the checkpoint in ``model_folder`` whose filename sorts last,
    segments every ``*_ct.nii.gz`` volume in ``data_folder`` with sliding-window
    inference averaged over test-time augmentations, saves the results under
    ``prediction_folder`` and copies them into ``prediction_folder/to_submit``.

    Args:
        data_folder: folder containing the volumes to segment.
        model_folder: folder containing ``*.pt`` checkpoint files.
        prediction_folder: output folder for the saved segmentations.

    Raises:
        FileNotFoundError: if ``model_folder`` contains no ``*.pt`` file.
    """
    ckpts = sorted(glob.glob(os.path.join(model_folder, "*.pt")))
    if not ckpts:
        # fail with a clear message instead of an IndexError on ckpts[-1]
        raise FileNotFoundError(f"no checkpoint (*.pt) found in model folder: {model_folder}")
    ckpt = ckpts[-1]
    for x in ckpts:
        logging.info(f"available model file: {x}.")
    logging.info("----")
    logging.info(f"using {ckpt}.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    net.load_state_dict(torch.load(ckpt, map_location=device))
    net.eval()

    image_folder = os.path.abspath(data_folder)
    images = sorted(glob.glob(os.path.join(image_folder, "*_ct.nii.gz")))
    logging.info(f"infer: image ({len(images)}) folder: {data_folder}")
    infer_files = [{"image": img} for img in images]

    keys = ("image",)
    infer_transforms = get_xforms("infer", keys)
    infer_ds = monai.data.Dataset(data=infer_files, transform=infer_transforms)
    infer_loader = monai.data.DataLoader(
        infer_ds,
        batch_size=1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    inferer = get_inferer()
    saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest")
    with torch.no_grad():
        for infer_data in infer_loader:
            logging.info(f"segmenting {infer_data['image_meta_dict']['filename_or_obj']}")
            preds = inferer(infer_data[keys[0]].to(device), net)
            n = 1.0
            for _ in range(4):
                # test time augmentations: one noisy copy plus its flips along dims 2 and 3
                _img = RandGaussianNoised(keys[0], prob=1.0, std=0.01)(infer_data)[keys[0]]
                _img = _img.to(device)  # move once, reuse for the flipped passes
                pred = inferer(_img, net)
                preds = preds + pred
                n = n + 1.0
                for dims in [[2], [3]]:
                    flip_pred = inferer(torch.flip(_img, dims=dims), net)
                    # flip the prediction back so it aligns with the unflipped ones
                    pred = torch.flip(flip_pred, dims=dims)
                    preds = preds + pred
                    n = n + 1.0
            # average all passes, then take the per-voxel arg-max class
            preds = preds / n
            preds = (preds.argmax(dim=1, keepdims=True)).float()
            saver.save_batch(preds, infer_data["image_meta_dict"])

    # copy the saved segmentations into the required folder structure for submission
    submission_dir = os.path.join(prediction_folder, "to_submit")
    os.makedirs(submission_dir, exist_ok=True)
    files = glob.glob(os.path.join(prediction_folder, "volume*", "*.nii.gz"))
    for f in files:
        new_name = os.path.basename(f)
        # Prefix and suffix are removed by LENGTH, not by matching, so this
        # assumes names of the form "volume-covid19-A-0<id>_ct_seg.nii.gz".
        new_name = new_name[len("volume-covid19-A-0"):]
        new_name = new_name[: -len("_ct_seg.nii.gz")] + ".nii.gz"
        to_name = os.path.join(submission_dir, new_name)
        shutil.copy(f, to_name)
    logging.info(f"predictions copied to {submission_dir}.")
if __name__ == "__main__":
    # Usage:
    #     python run_net.py train --data_folder "COVID-19-20_v2/Train"  # run the training pipeline
    #     python run_net.py infer --data_folder "COVID-19-20_v2/Validation"  # run the inference pipeline
    parser = argparse.ArgumentParser(description="Run a basic UNet segmentation baseline.")
    parser.add_argument(
        "mode",
        metavar="mode",
        nargs="?",  # a plain positional is always required, so ``default`` would never apply
        default="train",
        choices=("train", "infer"),
        type=str,
        help="mode of workflow",
    )
    parser.add_argument(
        "--data_folder",
        default="",
        type=str,
        help="data folder (defaults to COVID-19-20_v2/Train for train, COVID-19-20_v2/Validation for infer)",
    )
    parser.add_argument("--model_folder", default="runs", type=str, help="model folder")
    args = parser.parse_args()

    monai.config.print_config()
    monai.utils.set_determinism(seed=0)
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if args.mode == "train":
        # an empty --data_folder falls back to the default training folder
        data_folder = args.data_folder or os.path.join("COVID-19-20_v2", "Train")
        train(data_folder=data_folder, model_folder=args.model_folder)
    elif args.mode == "infer":
        # an empty --data_folder falls back to the default validation folder
        data_folder = args.data_folder or os.path.join("COVID-19-20_v2", "Validation")
        infer(data_folder=data_folder, model_folder=args.model_folder)
    else:
        # defensive: argparse ``choices`` already restricts the value
        raise ValueError("Unknown mode.")
Beta Was this translation helpful? Give feedback.
-
Hi @Nic-Ma, regarding the SupervisedEvaluator in the model I sent here, I have a question about its function. Best,
Beta Was this translation helpful? Give feedback.
Uh oh!
There was an error while loading. Please reload this page.
-
Hi!
How does `SupervisedTrainer` handle non-binary labels (values between 0 and 1) when a binary label is expected?
Does the model assume that any label pixel > 0 maps to 1, or does it do something like thresholding: everything < 0.5 --> 0 and everything > 0.5 --> 1?
Beta Was this translation helpful? Give feedback.
All reactions