|
| 1 | +import os |
| 2 | +import argparse |
| 3 | +from tqdm import tqdm |
| 4 | +import torch |
| 5 | +import torchvision.transforms as transforms |
| 6 | +from med_seg_diff_pytorch import Unet, MedSegDiff |
| 7 | +from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset |
| 8 | +from accelerate import Accelerator |
| 9 | +import skimage.io as io |
| 10 | + |
| 11 | + |
| 12 | + |
## Parse CLI arguments ##
def parse_args():
    """Parse command-line arguments for ensemble sampling/inference.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
                        help="Whether to do mixed precision")
    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
                        help='The image file path from data_path')
    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
                        help='The csv file to load in from data_path')
    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
    # Help texts below were previously swapped/wrong: this flag is the number of
    # segmentation-mask channels (default 1), the next is image input channels.
    parser.add_argument('-ic', '--mask_channels', type=int, default=1,
                        help='number of mask (segmentation) channels (default: 1)')
    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
                        help='number of input image channels (default: 3)')
    parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
    parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
    parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')
    parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
    parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
    parser.add_argument('--save_every', type=int, default=100, help='save_every n epochs (default: 100)')
    parser.add_argument('--num_ens', type=int, default=5,
                        help='number of times to sample to make an ensemble of predictions like in the paper (default: 5)')
    parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')
    parser.add_argument('--save_uncertainty', action='store_true',
                        help='Whether to store the uncertainty in predictions (only works for ensembles)')

    return parser.parse_args()
| 43 | + |
| 44 | + |
def load_data(args):
    """Build the evaluation DataLoader for the configured dataset.

    Args:
        args: parsed CLI namespace; reads dataset, data_path, csv_file,
            img_folder, image_size and batch_size.

    Returns:
        torch.utils.data.DataLoader yielding unshuffled evaluation batches.

    Raises:
        NotImplementedError: if args.dataset is neither 'ISIC' nor 'generic'.
    """
    if args.dataset == 'ISIC':
        eval_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
        ])
        dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder,
                              transform=eval_transform, training=False, flip_p=0.5)
    elif args.dataset == 'generic':
        eval_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(args.image_size),
            transforms.ToTensor(),
        ])
        dataset = GenericNpyDataset(args.data_path, transform=eval_transform, test_flag=True)
    else:
        raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")

    # No shuffling: keep deterministic order so outputs line up with filenames.
    return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
| 66 | + |
| 67 | + |
def main():
    """Sample an ensemble of segmentation predictions and save them as PNGs.

    Loads the model (optionally from a checkpoint), draws ``num_ens`` samples
    per image, and writes the per-pixel ensemble mean (and optionally the
    standard deviation as an uncertainty map) into ``<output_dir>/inference``.
    """
    args = parse_args()
    inference_dir = os.path.join(args.output_dir, 'inference')
    os.makedirs(inference_dir, exist_ok=True)
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )
    # Experiment tracking (init_trackers) is intentionally disabled for inference.

    ## DEFINE MODEL ##
    model = Unet(
        dim=args.dim,
        image_size=args.image_size,
        dim_mults=(1, 2, 4, 8),
        mask_channels=args.mask_channels,
        input_img_channels=args.input_img_channels,
        self_condition=args.self_condition
    )

    ## LOAD DATA ##
    data_loader = load_data(args)

    diffusion = MedSegDiff(
        model,
        timesteps=args.timesteps
    ).to(accelerator.device)

    if args.load_model_from is not None:
        save_dict = torch.load(args.load_model_from)
        diffusion.model.load_state_dict(save_dict['model_state_dict'])

    for (imgs, masks, fnames) in tqdm(data_loader):
        # BUG FIX: the model lives on accelerator.device but the batch was left
        # on CPU, which fails whenever a GPU is selected. Move it explicitly.
        imgs = imgs.to(accelerator.device)
        # Pre-allocate one prediction slot per ensemble member.
        # NOTE(review): assumes sample() returns a single mask channel
        # (mask_channels == 1), matching the (B, num_ens, H, W) layout — confirm.
        preds = torch.zeros((imgs.shape[0], args.num_ens, imgs.shape[2], imgs.shape[3]))
        for member in range(args.num_ens):
            preds[:, member:member + 1, :, :] = diffusion.sample(imgs).cpu().detach()
        preds_mean = preds.mean(dim=1)

        for idx in range(preds.shape[0]):
            # Convert to numpy explicitly; skimage expects array-like input.
            io.imsave(os.path.join(inference_dir, fnames[idx].replace('.npy', '.png')),
                      preds_mean[idx, :, :].numpy())
            if args.save_uncertainty:
                # Only compute the std map when it is actually saved.
                preds_std = preds.std(dim=1)
                io.imsave(os.path.join(inference_dir, fnames[idx].replace('.npy', '_std.png')),
                          preds_std[idx, :, :].numpy())
| 114 | + |
| 115 | + |
# Script entry point: run inference only when executed directly.
if __name__ == '__main__':
    main()
0 commit comments