1+ import os
12import argparse
2- import wandb
3+ from tqdm import tqdm
4+
35import torch
6+ import torchvision .transforms as transforms
7+
8+ from torch .optim import AdamW
9+ from lion_pytorch import Lion
10+
411from med_seg_diff_pytorch import Unet , MedSegDiff
512from med_seg_diff_pytorch .dataset import ISICDataset
6- import torchvision .transforms as transforms
7- from tqdm import tqdm
13+
814from accelerate import Accelerator
9- import os
15+ import wandb
1016
1117## Parse CLI arguments ##
1218def parse_args ():
@@ -25,6 +31,7 @@ def parse_args():
2531 parser .add_argument ('-ab2' , '--adam_beta2' , type = float , default = 0.999 , help = 'The beta2 parameter for the Adam optimizer.' )
2632 parser .add_argument ('-aw' , '--adam_weight_decay' , type = float , default = 1e-6 , help = 'Weight decay magnitude for the Adam optimizer.' )
2733 parser .add_argument ('-ae' , '--adam_epsilon' , type = float , default = 1e-08 , help = 'Epsilon value for the Adam optimizer.' )
34+ parser .add_argument ('-ul' , '--use_lion' , type = bool , default = False , help = 'use Lion optimizer' )
2835 parser .add_argument ('-ic' , '--mask_channels' , type = int , default = 1 , help = 'input channels for training (default: 3)' )
2936 parser .add_argument ('-c' , '--input_img_channels' , type = int , default = 3 , help = 'output channels for training (default: 3)' )
3037 parser .add_argument ('-is' , '--image_size' , type = int , default = 128 , help = 'input image size (default: 128)' )
@@ -86,15 +93,23 @@ def main():
8693 args .learning_rate = (
8794 args .learning_rate * args .gradient_accumulation_steps * args .batch_size * accelerator .num_processes
8895 )
89- ## Initialize optimizer
90- optimizer = torch .optim .AdamW (
91- model .parameters (),
92- lr = args .learning_rate ,
93- betas = (args .adam_beta1 , args .adam_beta2 ),
94- weight_decay = args .adam_weight_decay ,
95- eps = args .adam_epsilon ,
96- )
9796
97+ ## Initialize optimizer
98+ if not args .use_lion :
99+ optimizer = AdamW (
100+ model .parameters (),
101+ lr = args .learning_rate ,
102+ betas = (args .adam_beta1 , args .adam_beta2 ),
103+ weight_decay = args .adam_weight_decay ,
104+ eps = args .adam_epsilon ,
105+ )
106+ else :
107+ optimizer = Lion (
108+ model .parameters (),
109+ lr = args .learning_rate ,
110+ betas = (args .adam_beta1 , args .adam_beta2 ),
111+ weight_decay = args .adam_weight_decay
112+ )
98113
99114 ## TRAIN MODEL ##
100115 running_loss = 0.0
0 commit comments