Skip to content

Commit c96a72e

Browse files
author
Aaron
committed
1. added a generic dataset loader
2. fixed driver.py to pass timesteps=args.timesteps instead of args.epochs; 3. added model saving in driver.py; 4. added model loading in driver.py; 5. fixed the wandb image logger, which requires channels-last images or no channel dimension; 6. added the ability to evaluate and save the model every n epochs to train faster.
1 parent fe138f8 commit c96a72e

File tree

2 files changed

+125
-48
lines changed

2 files changed

+125
-48
lines changed

driver.py

Lines changed: 75 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,69 @@
11
import os
22
import argparse
33
from tqdm import tqdm
4-
54
import torch
5+
import numpy as np
66
import torchvision.transforms as transforms
7-
87
from torch.optim import AdamW
98
from lion_pytorch import Lion
10-
119
from med_seg_diff_pytorch import Unet, MedSegDiff
12-
from med_seg_diff_pytorch.dataset import ISICDataset
13-
10+
from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
1411
from accelerate import Accelerator
1512
import wandb
1613

1714
## Parse CLI arguments ##
1815
def parse_args():
    """Parse command-line arguments for diffusion-segmentation training.

    Returns:
        argparse.Namespace: parsed training configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-slr', '--scale_lr', action='store_true', help="Whether to scale lr.")
    parser.add_argument('-rt', '--report_to', type=str, default="wandb", choices=["wandb"],
                        help="Where to log to. Currently only supports wandb")
    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
                        help="Whether to do mixed precision")
    parser.add_argument('-ga', '--gradient_accumulation_steps', type=int, default=4,
                        help="The number of gradient accumulation steps.")
    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
                        help='The image file path from data_path')
    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
                        help='The csv file to load in from data_path')
    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4, help='learning rate')
    parser.add_argument('-ab1', '--adam_beta1', type=float, default=0.95,
                        help='The beta1 parameter for the Adam optimizer.')
    parser.add_argument('-ab2', '--adam_beta2', type=float, default=0.999,
                        help='The beta2 parameter for the Adam optimizer.')
    parser.add_argument('-aw', '--adam_weight_decay', type=float, default=1e-6,
                        help='Weight decay magnitude for the Adam optimizer.')
    parser.add_argument('-ae', '--adam_epsilon', type=float, default=1e-08,
                        help='Epsilon value for the Adam optimizer.')
    # NOTE(review): type=bool is misleading in argparse — any non-empty string
    # (including "False") parses as True. Kept for CLI compatibility; prefer
    # action='store_true' long term.
    parser.add_argument('-ul', '--use_lion', type=bool, default=False, help='use Lion optimizer')
    # Help text fixed: this is the mask channel count and its default is 1, not 3.
    parser.add_argument('-ic', '--mask_channels', type=int, default=1,
                        help='mask channels for training (default: 1)')
    # Help text fixed: this is the input image channel count (was mislabeled "output channels").
    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
                        help='input image channels for training (default: 3)')
    parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
    parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
    parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')
    parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
    parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
    parser.add_argument('--save_every', type=int, default=100, help='save model every n epochs (default: 100)')
    parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')
    return parser.parse_args()
4454

4555

4656
def load_data(args):
47-
# Create transforms for data
48-
transform_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor(),]
49-
transform_train = transforms.Compose(transform_list)
50-
5157
# Load dataset
5258
if args.dataset == 'ISIC':
53-
dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder, transform = transform_train, training = True, flip_p=0.5)
59+
transform_list = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), ]
60+
transform_train = transforms.Compose(transform_list)
61+
dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder, transform=transform_train, training=True,
62+
flip_p=0.5)
63+
elif args.dataset == 'generic':
64+
transform_list = [transforms.ToPILImage(), transforms.Resize(args.image_size), transforms.ToTensor()]
65+
transform_train = transforms.Compose(transform_list)
66+
dataset = GenericNpyDataset(args.data_path, transform=transform_train, test_flag=False)
5467
else:
5568
raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")
5669

@@ -63,10 +76,11 @@ def load_data(args):
6376
return training_generator
6477

6578

66-
6779
def main():
6880
args = parse_args()
81+
checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
6982
logging_dir = os.path.join(args.output_dir, args.logging_dir)
83+
os.makedirs(checkpoint_dir, exist_ok=True)
7084
accelerator = Accelerator(
7185
gradient_accumulation_steps=args.gradient_accumulation_steps,
7286
mixed_precision=args.mixed_precision,
@@ -78,20 +92,20 @@ def main():
7892

7993
## DEFINE MODEL ##
8094
model = Unet(
81-
dim = args.dim,
82-
image_size = args.image_size,
83-
dim_mults = (1, 2, 4, 8),
84-
mask_channels = args.mask_channels,
85-
input_img_channels= args.input_img_channels,
86-
self_condition = args.self_condition
95+
dim=args.dim,
96+
image_size=args.image_size,
97+
dim_mults=(1, 2, 4, 8),
98+
mask_channels=args.mask_channels,
99+
input_img_channels=args.input_img_channels,
100+
self_condition=args.self_condition
87101
)
88102

89103
## LOAD DATA ##
90104
data_loader = load_data(args)
91-
#training_generator = tqdm(data_loader, total=int(len(data_loader)))
105+
# training_generator = tqdm(data_loader, total=int(len(data_loader)))
92106
if args.scale_lr:
93107
args.learning_rate = (
94-
args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
108+
args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
95109
)
96110

97111
## Initialize optimizer
@@ -112,22 +126,29 @@ def main():
112126
)
113127

114128
## TRAIN MODEL ##
115-
running_loss = 0.0
116129
counter = 0
117130
model, optimizer, data_loader = accelerator.prepare(
118131
model, optimizer, data_loader
119132
)
120133
diffusion = MedSegDiff(
121134
model,
122-
timesteps = args.epochs
135+
timesteps=args.timesteps
123136
).to(accelerator.device)
137+
138+
if args.load_model_from is not None:
139+
save_dict = torch.load(args.load_model_from)
140+
diffusion.model.load_state_dict(save_dict['model_state_dict'])
141+
optimizer.load_state_dict(save_dict['optimizer_state_dict'])
142+
accelerator.print(f'Loaded from {args.load_model_from}')
143+
124144
## Iterate across training loop
125145
for epoch in range(args.epochs):
126-
print('Epoch {}/{}'.format(epoch+1, args.epochs))
146+
running_loss = 0.0
147+
print('Epoch {}/{}'.format(epoch + 1, args.epochs))
127148
for (img, mask) in tqdm(data_loader):
128149
with accelerator.accumulate(model):
129150
loss = diffusion(mask, img)
130-
accelerator.log({'loss': loss}) # Log loss to wandb
151+
accelerator.log({'loss': loss}) # Log loss to wandb
131152
accelerator.backward(loss)
132153
optimizer.step()
133154
optimizer.zero_grad()
@@ -136,12 +157,24 @@ def main():
136157
epoch_loss = running_loss / len(data_loader)
137158
print('Training Loss : {:.4f}'.format(epoch_loss))
138159
## INFERENCE ##
139-
pred = diffusion.sample(img).cpu().detach().numpy()
140-
for tracker in accelerator.trackers:
141-
if tracker.name == "wandb":
142-
tracker.log(
143-
{'pred-img-mask': [wandb.Image(pred), wandb.Image(img), wandb.Image(mask)]}
144-
)
160+
161+
if epoch % args.save_every == 0:
162+
torch.save({
163+
'epoch': epoch,
164+
'model_state_dict': diffusion.model.state_dict(),
165+
'optimizer_state_dict': optimizer.state_dict(),
166+
'loss': loss,
167+
}, os.path.join(checkpoint_dir, f'state_dict_epoch_{epoch}_loss_{epoch_loss}.pt'))
168+
169+
pred = diffusion.sample(img).cpu().detach().numpy()
170+
171+
for tracker in accelerator.trackers:
172+
if tracker.name == "wandb":
173+
# save just one image per batch
174+
tracker.log(
175+
{'pred-img-mask': [wandb.Image(pred[0, 0, :, :]), wandb.Image(img[0, 0, :, :]),
176+
wandb.Image(mask[0, 0, :, :])]}
177+
)
145178

146179

147180
if __name__ == '__main__':

med_seg_diff_pytorch/dataset.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
import os
2-
os.environ['KMP_DUPLICATE_LIB_OK']='True'
2+
import numpy as np
3+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
34
import torch
45
from torch.utils.data import Dataset
56
from PIL import Image
67
import pandas as pd
78
import random
89
import torchvision.transforms.functional as F
10+
11+
912
class ISICDataset(Dataset):
10-
def __init__(self, data_path, csv_file, img_folder, transform = None, training = True, flip_p=0.5):
13+
def __init__(self, data_path, csv_file, img_folder, transform=None, training=True, flip_p=0.5):
1114
df = pd.read_csv(os.path.join(data_path, csv_file), encoding='gbk')
1215
self.img_folder = img_folder
13-
self.name_list = df.iloc[:,0].tolist()
14-
self.label_list = df.iloc[:,1].tolist()
16+
self.name_list = df.iloc[:, 0].tolist()
17+
self.label_list = df.iloc[:, 1].tolist()
1518
self.data_path = data_path
1619
self.transform = transform
1720
self.training = training
1821
self.flip_p = flip_p
22+
1923
def __len__(self):
2024
return len(self.name_list)
25+
2126
def __getitem__(self, index):
2227
"""Get the images"""
23-
name = self.name_list[index]+'.jpg'
28+
name = self.name_list[index] + '.jpg'
2429
img_path = os.path.join(self.data_path, self.img_folder, name)
2530

2631
mask_name = name.split('.')[0] + '_Segmentation.png'
@@ -35,13 +40,52 @@ def __getitem__(self, index):
3540
label = int(self.label_list[index])
3641

3742
if self.transform:
43+
# save random state so that if more elaborate transforms are used
44+
# the same transform will be applied to both the mask and the img
45+
state = torch.get_rng_state()
3846
img = self.transform(img)
47+
torch.set_rng_state(state)
3948
mask = self.transform(mask)
4049
if random.random() < self.flip_p:
4150
img = F.vflip(img)
4251
mask = F.vflip(mask)
4352

44-
4553
if self.training:
4654
return (img, mask)
4755
return (img, mask, label)
56+
57+
58+
class GenericNpyDataset(torch.utils.data.Dataset):
    """Generic dataset for loading .npy files.

    Each .npy file stores a 3D array whose first two dimensions are the
    image (H, W) and whose last dimension holds channels: channel 0 is
    the image and the remaining channel(s) are the label mask.
    """

    def __init__(self, directory: str, transform, test_flag: bool = True):
        """
        Args:
            directory: folder containing the .npy files ('~' is expanded).
            transform: callable applied to both image and mask with the same
                RNG state, so random transforms stay aligned; may be falsy.
            test_flag: whether this dataset is used for testing.
        """
        super().__init__()
        self.directory = os.path.expanduser(directory)
        self.transform = transform
        self.test_flag = test_flag
        # sorted() makes the index -> file mapping deterministic; os.listdir
        # order is arbitrary and varies across filesystems.
        self.filenames = [os.path.join(self.directory, x)
                          for x in sorted(os.listdir(self.directory)) if x.endswith('.npy')]

    def __getitem__(self, x: int):
        """Return a (image, mask) pair of CHW tensors for sample index *x*."""
        npy_img = np.load(self.filenames[x])
        # Channel 0 is the image; permute HWC -> CHW for torch.
        image = torch.from_numpy(npy_img[:, :, :1]).permute(2, 0, 1)
        # Remaining channels are the label; binarize to {0, 1}.
        mask = np.where(npy_img[:, :, 1:] > 0, 1, 0)
        mask = torch.from_numpy(mask).permute(2, 0, 1).float()
        if self.transform:
            # Save random state so that if more elaborate transforms are used
            # the same transform will be applied to both the mask and the img.
            state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(state)
            mask = self.transform(mask)
        return image, mask

    def __len__(self) -> int:
        return len(self.filenames)

0 commit comments

Comments
 (0)