import os
import argparse
from tqdm import tqdm

import torch
import numpy as np
import torchvision.transforms as transforms

from torch.optim import AdamW
from lion_pytorch import Lion

from med_seg_diff_pytorch import Unet, MedSegDiff
from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
from accelerate import Accelerator
import wandb

## Parse CLI arguments ##
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-slr', '--scale_lr', action='store_true', help="Whether to scale lr.")
    parser.add_argument('-rt', '--report_to', type=str, default="wandb", choices=["wandb"],
                        help="Where to log to. Currently only supports wandb.")
    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
                        help="Whether to do mixed precision.")
    parser.add_argument('-ga', '--gradient_accumulation_steps', type=int, default=4,
                        help="The number of gradient accumulation steps.")
    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
                        help='The image folder, relative to data_path.')
    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
                        help='The csv file to load, relative to data_path.')
    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to use self-conditioning.')
    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4, help='Learning rate.')
    parser.add_argument('-ab1', '--adam_beta1', type=float, default=0.95,
                        help='The beta1 parameter for the Adam optimizer.')
    parser.add_argument('-ab2', '--adam_beta2', type=float, default=0.999,
                        help='The beta2 parameter for the Adam optimizer.')
    parser.add_argument('-aw', '--adam_weight_decay', type=float, default=1e-6,
                        help='Weight decay magnitude for the Adam optimizer.')
    parser.add_argument('-ae', '--adam_epsilon', type=float, default=1e-08,
                        help='Epsilon value for the Adam optimizer.')
    # store_true instead of type=bool: argparse treats any non-empty string as True
    parser.add_argument('-ul', '--use_lion', action='store_true', help='Use the Lion optimizer instead of AdamW.')
    parser.add_argument('-ic', '--mask_channels', type=int, default=1,
                        help='Number of mask channels (default: 1).')
    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
                        help='Number of input image channels (default: 3).')
    parser.add_argument('-is', '--image_size', type=int, default=128, help='Input image size (default: 128).')
    parser.add_argument('-dd', '--data_path', default='./data', help='Directory of the input data.')
    parser.add_argument('-d', '--dim', type=int, default=64, help='Base Unet dim (default: 64).')
    parser.add_argument('-e', '--epochs', type=int, default=10000, help='Number of epochs (default: 10000).')
    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='Batch size to train on (default: 8).')
    parser.add_argument('--timesteps', type=int, default=1000, help='Number of diffusion timesteps (default: 1000).')
    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use.')
    parser.add_argument('--save_every', type=int, default=100, help='Save a checkpoint every n epochs (default: 100).')
    parser.add_argument('--load_model_from', default=None, help='Path to a .pt checkpoint to resume from.')
    return parser.parse_args()


def load_data(args):
    # Load dataset
    if args.dataset == 'ISIC':
        transform_list = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]
        transform_train = transforms.Compose(transform_list)
        dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder, transform=transform_train,
                              training=True, flip_p=0.5)
    elif args.dataset == 'generic':
        transform_list = [transforms.ToPILImage(), transforms.Resize(args.image_size), transforms.ToTensor()]
        transform_train = transforms.Compose(transform_list)
        dataset = GenericNpyDataset(args.data_path, transform=transform_train, test_flag=False)
    else:
        raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")

    # The DataLoader construction is elided in the source diff; a minimal
    # reconstruction from the surrounding code (exact upstream lines may differ):
    training_generator = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True)
    return training_generator


def main():
    args = parse_args()
    checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    os.makedirs(checkpoint_dir, exist_ok=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
    )
    # Tracker initialization is elided in the source diff; a plausible
    # reconstruction using the standard accelerate pattern (the project name
    # here is an assumption):
    if accelerator.is_main_process:
        accelerator.init_trackers("med-seg-diff", config=vars(args))

    ## DEFINE MODEL ##
    model = Unet(
        dim=args.dim,
        image_size=args.image_size,
        dim_mults=(1, 2, 4, 8),
        mask_channels=args.mask_channels,
        input_img_channels=args.input_img_channels,
        self_condition=args.self_condition
    )
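    # mask_channels is the number of channels in the segmentation mask being
    # denoised; input_img_channels is the raw image used as conditioning.
    # self_condition additionally feeds the model's previous mask estimate
    # back in at each step.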

    ## LOAD DATA ##
    data_loader = load_data(args)
    # training_generator = tqdm(data_loader, total=int(len(data_loader)))
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
        )
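    # Note: scaling multiplies the base lr by the effective batch size
    # (gradient accumulation steps x per-device batch size x number of processes).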

    ## Initialize optimizer
    # The optimizer construction is elided in the source diff; a minimal
    # reconstruction from the CLI arguments defined above (exact upstream
    # lines may differ):
    if not args.use_lion:
        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )
    else:
        optimizer = Lion(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
        )

    ## TRAIN MODEL ##
    counter = 0
    model, optimizer, data_loader = accelerator.prepare(
        model, optimizer, data_loader
    )
    diffusion = MedSegDiff(
        model,
        timesteps=args.timesteps
    ).to(accelerator.device)
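    # Note: the diffusion chain length (--timesteps) is now independent of the
    # number of training epochs (--epochs).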

    if args.load_model_from is not None:
        save_dict = torch.load(args.load_model_from)
        diffusion.model.load_state_dict(save_dict['model_state_dict'])
        optimizer.load_state_dict(save_dict['optimizer_state_dict'])
        accelerator.print(f'Loaded from {args.load_model_from}')

    ## Iterate across training loop
    for epoch in range(args.epochs):
        running_loss = 0.0
        print('Epoch {}/{}'.format(epoch + 1, args.epochs))
        for (img, mask) in tqdm(data_loader):
            with accelerator.accumulate(model):
                loss = diffusion(mask, img)  # denoise the mask, conditioned on the image
                accelerator.log({'loss': loss.item()})  # Log loss to wandb
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            # The loss accumulation is elided in the source diff; a minimal
            # reconstruction so epoch_loss below is the mean per-batch loss:
            running_loss += loss.item()
            counter += 1
        epoch_loss = running_loss / len(data_loader)
        print('Training Loss : {:.4f}'.format(epoch_loss))
        ## INFERENCE ##

        if epoch % args.save_every == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': diffusion.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, os.path.join(checkpoint_dir, f'state_dict_epoch_{epoch}_loss_{epoch_loss}.pt'))
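            # The saved keys mirror what --load_model_from expects when resuming.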

            pred = diffusion.sample(img).cpu().detach().numpy()

            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    # save just one image per batch
                    tracker.log(
                        {'pred-img-mask': [wandb.Image(pred[0, 0, :, :]), wandb.Image(img[0, 0, :, :]),
                                           wandb.Image(mask[0, 0, :, :])]}
                    )


if __name__ == '__main__':
    main()
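
# Example invocation (a sketch: the script filename "driver.py" and the exact
# flag combination shown are assumptions, not taken from the source):
#
#   accelerate launch driver.py --dataset generic --data_path ./data \
#       --batch_size 8 --timesteps 1000 --save_every 100 --mixed_precision fp16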