
Commit 19729b5

Author: Aaron

1. Added sample.py to sample from a trained model.
2. --num_ens sets the number of samples in the ensemble; --save_uncertainty saves the std across the ensemble.

Parent: 1fc2f82
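A typical invocation of the new sampling script, assuming a generic .npy dataset under ./data and a checkpoint written by driver.py (the checkpoint path below is only a placeholder), might look like:

python sample.py --dataset generic --data_path ./data --load_model_from output/checkpoint.pt --num_ens 5 --save_uncertainty

sample.py expects the checkpoint to be a dict containing a 'model_state_dict' entry, and the model flags (--dim, --image_size, --mask_channels, --input_img_channels, --self_condition) must match the values used at training time, since the loaded weights go into a Unet built from them.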

File tree: 3 files changed, +122 -4 lines

driver.py
med_seg_diff_pytorch/dataset.py
sample.py

driver.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def parse_args():
     parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
     parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
     parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
-    parser.add_argument('--save_every', type=int, default=100, help='save_every n rpochs (default: 100)')
+    parser.add_argument('--save_every', type=int, default=100, help='save_every n epochs (default: 100)')
     parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')
     return parser.parse_args()

med_seg_diff_pytorch/dataset.py

Lines changed: 4 additions & 3 deletions
@@ -66,11 +66,11 @@ def __init__(self, directory: str, transform, test_flag: bool = True):
         self.directory = os.path.expanduser(directory)
         self.transform = transform
         self.test_flag = test_flag
-        self.filenames = [os.path.join(self.directory, x) for x in os.listdir(self.directory) if x.endswith('.npy')]
+        self.filenames = [x for x in os.listdir(self.directory) if x.endswith('.npy')]

     def __getitem__(self, x: int):
         fname = self.filenames[x]
-        npy_img = np.load(fname)
+        npy_img = np.load(os.path.join(self.directory, fname))
         img = npy_img[:, :, :1]
         img = torch.from_numpy(img).permute(2, 0, 1)
         mask = npy_img[:, :, 1:]
@@ -84,7 +84,8 @@ def __getitem__(self, x: int):
         image = self.transform(image)
         torch.set_rng_state(state)
         mask = self.transform(mask)
-
+        if self.test_flag:
+            return image, mask, fname
         return image, mask

     def __len__(self) -> int:
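The net effect of this change is that with test_flag=True the dataset also yields the source filename, which sample.py uses to name its output PNGs. A minimal sketch of the new behaviour (assuming ./data holds H x W x C .npy arrays with the image in the leading channel(s) and the mask in the trailing one, as __getitem__ slices them):

import torchvision.transforms as transforms
from med_seg_diff_pytorch.dataset import GenericNpyDataset

# Same transform chain that sample.py builds for the 'generic' dataset.
transform = transforms.Compose([transforms.ToPILImage(), transforms.Resize(128), transforms.ToTensor()])

dataset = GenericNpyDataset('./data', transform=transform, test_flag=True)
image, mask, fname = dataset[0]   # with test_flag=True each item also carries its .npy filename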

sample.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+import os
+import argparse
+from tqdm import tqdm
+import torch
+import torchvision.transforms as transforms
+from med_seg_diff_pytorch import Unet, MedSegDiff
+from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
+from accelerate import Accelerator
+import skimage.io as io
+
+
+
+## Parse CLI arguments ##
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
+    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
+    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
+                        help="Whether to do mixed precision")
+    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
+                        help='The image file path from data_path')
+    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
+                        help='The csv file to load in from data_path')
+    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
+    parser.add_argument('-ic', '--mask_channels', type=int, default=1, help='number of mask channels (default: 1)')
+    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
+                        help='number of input image channels (default: 3)')
+    parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
+    parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
+    parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')
+    parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
+    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
+    parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
+    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
+    parser.add_argument('--save_every', type=int, default=100, help='save_every n epochs (default: 100)')
+    parser.add_argument('--num_ens', type=int, default=5,
+                        help='number of times to sample to make an ensemble of predictions like in the paper (default: 5)')
+    parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')
+    parser.add_argument('--save_uncertainty', action='store_true',
+                        help='Whether to store the uncertainty in predictions (only works for ensembles)')
+
+    return parser.parse_args()
+
+
+def load_data(args):
+    # Load dataset
+    if args.dataset == 'ISIC':
+        transform_list = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), ]
+        transform_train = transforms.Compose(transform_list)
+        dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder, transform=transform_train, training=False,
+                              flip_p=0.5)
+    elif args.dataset == 'generic':
+        transform_list = [transforms.ToPILImage(), transforms.Resize(args.image_size), transforms.ToTensor()]
+        transform_train = transforms.Compose(transform_list)
+        dataset = GenericNpyDataset(args.data_path, transform=transform_train, test_flag=True)
+    else:
+        raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")
+
+    ## Define PyTorch data generator
+    training_generator = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=False)
+
+    return training_generator
+
+
+def main():
+    args = parse_args()
+    logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    inference_dir = os.path.join(args.output_dir, 'inference')
+    os.makedirs(inference_dir, exist_ok=True)
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+    )
+    # if accelerator.is_main_process:
+    #     accelerator.init_trackers("med-seg-diff", config=vars(args))
+
+    ## DEFINE MODEL ##
+    model = Unet(
+        dim=args.dim,
+        image_size=args.image_size,
+        dim_mults=(1, 2, 4, 8),
+        mask_channels=args.mask_channels,
+        input_img_channels=args.input_img_channels,
+        self_condition=args.self_condition
+    )
+
+    ## LOAD DATA ##
+    data_loader = load_data(args)
+
+    diffusion = MedSegDiff(
+        model,
+        timesteps=args.timesteps
+    ).to(accelerator.device)
+
+    if args.load_model_from is not None:
+        save_dict = torch.load(args.load_model_from)
+        diffusion.model.load_state_dict(save_dict['model_state_dict'])
+
+    for (imgs, masks, fnames) in tqdm(data_loader):
+        # pre-allocate the ensemble predictions: (batch, num_ens, H, W)
+        preds = torch.zeros((imgs.shape[0], args.num_ens, imgs.shape[2], imgs.shape[3]))
+        for i in range(args.num_ens):
+            preds[:, i:i+1, :, :] = diffusion.sample(imgs).cpu().detach()
+        preds_mean = preds.mean(dim=1)
+        preds_std = preds.std(dim=1)
+
+        for idx in range(preds.shape[0]):
+            io.imsave(os.path.join(inference_dir, fnames[idx].replace('.npy', '.png')), preds_mean[idx, :, :])
+            if args.save_uncertainty:
+                io.imsave(os.path.join(inference_dir, fnames[idx].replace('.npy', '_std.png')), preds_std[idx, :, :])
+
+
+
+if __name__ == '__main__':
+    main()
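The core of sample.py is the ensemble loop: each batch of images is passed through diffusion.sample() num_ens times, the per-pixel mean of the samples becomes the saved mask, and the per-pixel std becomes the optional uncertainty map. A toy illustration of that reduction, with random tensors standing in for the diffusion samples:

import torch

batch, num_ens, image_size = 2, 5, 128
# Stand-in for num_ens diffusion samples per image, shape (batch, num_ens, H, W).
preds = torch.rand(batch, num_ens, image_size, image_size)

preds_mean = preds.mean(dim=1)   # per-pixel ensemble mean -> saved as <name>.png
preds_std = preds.std(dim=1)     # per-pixel std -> saved as <name>_std.png when --save_uncertainty is set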
