Skip to content

Commit 3a67e25

Browse files
authored
Merge pull request #7 from bf2harven/main
fixes
2 parents fe138f8 + c96a72e commit 3a67e25

File tree

2 files changed

+125
-48
lines changed

2 files changed

+125
-48
lines changed

driver.py

Lines changed: 75 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,69 @@
11
import os
22
import argparse
33
from tqdm import tqdm
4-
54
import torch
5+
import numpy as np
66
import torchvision.transforms as transforms
7-
87
from torch.optim import AdamW
98
from lion_pytorch import Lion
10-
119
from med_seg_diff_pytorch import Unet, MedSegDiff
12-
from med_seg_diff_pytorch.dataset import ISICDataset
13-
10+
from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
1411
from accelerate import Accelerator
1512
import wandb
1613

1714
## Parse CLI arguments ##
def _str2bool(value):
    """Parse a CLI string into a bool.

    argparse's ``type=bool`` treats ANY non-empty string as True (so
    ``--use_lion False`` would enable Lion); this helper parses the
    common truthy spellings instead.
    """
    return str(value).strip().lower() in ('1', 'true', 'yes', 'y')


def parse_args():
    """Build and parse the command-line arguments for training.

    Returns:
        argparse.Namespace: the parsed training configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-slr', '--scale_lr', action='store_true', help="Whether to scale lr.")
    parser.add_argument('-rt', '--report_to', type=str, default="wandb", choices=["wandb"],
                        help="Where to log to. Currently only supports wandb")
    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
                        help="Whether to do mixed precision")
    parser.add_argument('-ga', '--gradient_accumulation_steps', type=int, default=4,
                        help="The number of gradient accumulation steps.")
    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
                        help='The image file path from data_path')
    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
                        help='The csv file to load in from data_path')
    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4, help='learning rate')
    parser.add_argument('-ab1', '--adam_beta1', type=float, default=0.95,
                        help='The beta1 parameter for the Adam optimizer.')
    parser.add_argument('-ab2', '--adam_beta2', type=float, default=0.999,
                        help='The beta2 parameter for the Adam optimizer.')
    parser.add_argument('-aw', '--adam_weight_decay', type=float, default=1e-6,
                        help='Weight decay magnitude for the Adam optimizer.')
    parser.add_argument('-ae', '--adam_epsilon', type=float, default=1e-08,
                        help='Epsilon value for the Adam optimizer.')
    # was `type=bool`, which parsed every non-empty string (including "False") as True
    parser.add_argument('-ul', '--use_lion', type=_str2bool, default=False, help='use Lion optimizer')
    # help text fixed: this is the number of mask/output channels, and its default is 1
    parser.add_argument('-ic', '--mask_channels', type=int, default=1,
                        help='number of mask channels for training (default: 1)')
    # help text fixed: this is the number of INPUT image channels
    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
                        help='number of input image channels for training (default: 3)')
    parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
    parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
    parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')
    parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
    parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
    parser.add_argument('--save_every', type=int, default=100, help='save every n epochs (default: 100)')
    parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')
    return parser.parse_args()
4454

4555

4656
def load_data(args):
47-
# Create transforms for data
48-
transform_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor(),]
49-
transform_train = transforms.Compose(transform_list)
50-
5157
# Load dataset
5258
if args.dataset == 'ISIC':
53-
dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder, transform = transform_train, training = True, flip_p=0.5)
59+
transform_list = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), ]
60+
transform_train = transforms.Compose(transform_list)
61+
dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder, transform=transform_train, training=True,
62+
flip_p=0.5)
63+
elif args.dataset == 'generic':
64+
transform_list = [transforms.ToPILImage(), transforms.Resize(args.image_size), transforms.ToTensor()]
65+
transform_train = transforms.Compose(transform_list)
66+
dataset = GenericNpyDataset(args.data_path, transform=transform_train, test_flag=False)
5467
else:
5568
raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")
5669

@@ -63,10 +76,11 @@ def load_data(args):
6376
return training_generator
6477

6578

66-
6779
def main():
6880
args = parse_args()
81+
checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
6982
logging_dir = os.path.join(args.output_dir, args.logging_dir)
83+
os.makedirs(checkpoint_dir, exist_ok=True)
7084
accelerator = Accelerator(
7185
gradient_accumulation_steps=args.gradient_accumulation_steps,
7286
mixed_precision=args.mixed_precision,
@@ -78,20 +92,20 @@ def main():
7892

7993
## DEFINE MODEL ##
8094
model = Unet(
81-
dim = args.dim,
82-
image_size = args.image_size,
83-
dim_mults = (1, 2, 4, 8),
84-
mask_channels = args.mask_channels,
85-
input_img_channels= args.input_img_channels,
86-
self_condition = args.self_condition
95+
dim=args.dim,
96+
image_size=args.image_size,
97+
dim_mults=(1, 2, 4, 8),
98+
mask_channels=args.mask_channels,
99+
input_img_channels=args.input_img_channels,
100+
self_condition=args.self_condition
87101
)
88102

89103
## LOAD DATA ##
90104
data_loader = load_data(args)
91-
#training_generator = tqdm(data_loader, total=int(len(data_loader)))
105+
# training_generator = tqdm(data_loader, total=int(len(data_loader)))
92106
if args.scale_lr:
93107
args.learning_rate = (
94-
args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
108+
args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
95109
)
96110

97111
## Initialize optimizer
@@ -112,22 +126,29 @@ def main():
112126
)
113127

114128
## TRAIN MODEL ##
115-
running_loss = 0.0
116129
counter = 0
117130
model, optimizer, data_loader = accelerator.prepare(
118131
model, optimizer, data_loader
119132
)
120133
diffusion = MedSegDiff(
121134
model,
122-
timesteps = args.epochs
135+
timesteps=args.timesteps
123136
).to(accelerator.device)
137+
138+
if args.load_model_from is not None:
139+
save_dict = torch.load(args.load_model_from)
140+
diffusion.model.load_state_dict(save_dict['model_state_dict'])
141+
optimizer.load_state_dict(save_dict['optimizer_state_dict'])
142+
accelerator.print(f'Loaded from {args.load_model_from}')
143+
124144
## Iterate across training loop
125145
for epoch in range(args.epochs):
126-
print('Epoch {}/{}'.format(epoch+1, args.epochs))
146+
running_loss = 0.0
147+
print('Epoch {}/{}'.format(epoch + 1, args.epochs))
127148
for (img, mask) in tqdm(data_loader):
128149
with accelerator.accumulate(model):
129150
loss = diffusion(mask, img)
130-
accelerator.log({'loss': loss}) # Log loss to wandb
151+
accelerator.log({'loss': loss}) # Log loss to wandb
131152
accelerator.backward(loss)
132153
optimizer.step()
133154
optimizer.zero_grad()
@@ -136,12 +157,24 @@ def main():
136157
epoch_loss = running_loss / len(data_loader)
137158
print('Training Loss : {:.4f}'.format(epoch_loss))
138159
## INFERENCE ##
139-
pred = diffusion.sample(img).cpu().detach().numpy()
140-
for tracker in accelerator.trackers:
141-
if tracker.name == "wandb":
142-
tracker.log(
143-
{'pred-img-mask': [wandb.Image(pred), wandb.Image(img), wandb.Image(mask)]}
144-
)
160+
161+
if epoch % args.save_every == 0:
162+
torch.save({
163+
'epoch': epoch,
164+
'model_state_dict': diffusion.model.state_dict(),
165+
'optimizer_state_dict': optimizer.state_dict(),
166+
'loss': loss,
167+
}, os.path.join(checkpoint_dir, f'state_dict_epoch_{epoch}_loss_{epoch_loss}.pt'))
168+
169+
pred = diffusion.sample(img).cpu().detach().numpy()
170+
171+
for tracker in accelerator.trackers:
172+
if tracker.name == "wandb":
173+
# save just one image per batch
174+
tracker.log(
175+
{'pred-img-mask': [wandb.Image(pred[0, 0, :, :]), wandb.Image(img[0, 0, :, :]),
176+
wandb.Image(mask[0, 0, :, :])]}
177+
)
145178

146179

147180
if __name__ == '__main__':

med_seg_diff_pytorch/dataset.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
import os
2-
os.environ['KMP_DUPLICATE_LIB_OK']='True'
2+
import numpy as np
3+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
34
import torch
45
from torch.utils.data import Dataset
56
from PIL import Image
67
import pandas as pd
78
import random
89
import torchvision.transforms.functional as F
10+
11+
912
class ISICDataset(Dataset):
10-
def __init__(self, data_path, csv_file, img_folder, transform = None, training = True, flip_p=0.5):
13+
def __init__(self, data_path, csv_file, img_folder, transform=None, training=True, flip_p=0.5):
    """Index the ISIC ground-truth csv.

    Column 0 of the csv holds the image names, column 1 the labels.
    The csv is read with 'gbk' encoding — presumably because the
    shipped ground-truth file uses it; TODO confirm against the data.
    """
    frame = pd.read_csv(os.path.join(data_path, csv_file), encoding='gbk')
    self.name_list = frame.iloc[:, 0].tolist()
    self.label_list = frame.iloc[:, 1].tolist()
    self.data_path = data_path
    self.img_folder = img_folder
    self.transform = transform
    self.training = training
    self.flip_p = flip_p
22+
1923
def __len__(self):
    """Return the number of samples: one per row of the ground-truth csv."""
    return len(self.name_list)
25+
2126
def __getitem__(self, index):
2227
"""Get the images"""
23-
name = self.name_list[index]+'.jpg'
28+
name = self.name_list[index] + '.jpg'
2429
img_path = os.path.join(self.data_path, self.img_folder, name)
2530

2631
mask_name = name.split('.')[0] + '_Segmentation.png'
@@ -35,13 +40,52 @@ def __getitem__(self, index):
3540
label = int(self.label_list[index])
3641

3742
if self.transform:
43+
# save random state so that if more elaborate transforms are used
44+
# the same transform will be applied to both the mask and the img
45+
state = torch.get_rng_state()
3846
img = self.transform(img)
47+
torch.set_rng_state(state)
3948
mask = self.transform(mask)
4049
if random.random() < self.flip_p:
4150
img = F.vflip(img)
4251
mask = F.vflip(mask)
4352

44-
4553
if self.training:
4654
return (img, mask)
4755
return (img, mask, label)
56+
57+
58+
class GenericNpyDataset(torch.utils.data.Dataset):
    """Generic dataset that loads samples from ``.npy`` files.

    Each file stores a 3D array laid out as (H, W, C): channel 0 is the
    image and the remaining channel(s) form the segmentation label,
    which is binarized (any value > 0 becomes 1).
    """

    def __init__(self, directory: str, transform, test_flag: bool = True):
        super().__init__()
        self.directory = os.path.expanduser(directory)
        self.transform = transform
        self.test_flag = test_flag
        # one sample per .npy file found directly in the directory
        self.filenames = [os.path.join(self.directory, name)
                          for name in os.listdir(self.directory)
                          if name.endswith('.npy')]

    def __getitem__(self, x: int):
        """Return an ``(image, mask)`` pair for sample index *x*."""
        volume = np.load(self.filenames[x])
        # channel 0 -> image tensor in (C, H, W) layout
        image = torch.from_numpy(volume[:, :, :1]).permute(2, 0, 1)
        # remaining channels -> binary float mask in (C, H, W) layout
        binary = np.where(volume[:, :, 1:] > 0, 1, 0)
        mask = torch.from_numpy(binary).permute(2, 0, 1).float()
        if self.transform:
            # Restore the RNG state between the two calls so any random
            # transform is applied identically to the image and its mask.
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)
        return image, mask

    def __len__(self) -> int:
        """Return the number of ``.npy`` files discovered."""
        return len(self.filenames)

0 commit comments

Comments
 (0)