Skip to content

Commit 2f8edd1

Browse files
committed
throw in lion
1 parent 5b45dc9 commit 2f8edd1

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

driver.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
import os
12
import argparse
2-
import wandb
3+
from tqdm import tqdm
4+
35
import torch
6+
import torchvision.transforms as transforms
7+
8+
from torch.optim import AdamW
9+
from lion_pytorch import Lion
10+
411
from med_seg_diff_pytorch import Unet, MedSegDiff
512
from med_seg_diff_pytorch.dataset import ISICDataset
6-
import torchvision.transforms as transforms
7-
from tqdm import tqdm
13+
814
from accelerate import Accelerator
9-
import os
15+
import wandb
1016

1117
## Parse CLI arguments ##
1218
def parse_args():
@@ -25,6 +31,7 @@ def parse_args():
2531
parser.add_argument('-ab2', '--adam_beta2', type=float, default=0.999, help='The beta2 parameter for the Adam optimizer.')
2632
parser.add_argument('-aw', '--adam_weight_decay', type=float, default=1e-6, help='Weight decay magnitude for the Adam optimizer.')
2733
parser.add_argument('-ae', '--adam_epsilon', type=float, default=1e-08, help='Epsilon value for the Adam optimizer.')
34+
parser.add_argument('-ul', '--use_lion', type=bool, default=False, help='use Lion optimizer')
2835
parser.add_argument('-ic', '--mask_channels', type=int, default=1, help='input channels for training (default: 3)')
2936
parser.add_argument('-c', '--input_img_channels', type=int, default=3, help='output channels for training (default: 3)')
3037
parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
@@ -86,15 +93,23 @@ def main():
8693
args.learning_rate = (
8794
args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
8895
)
89-
## Initialize optimizer
90-
optimizer = torch.optim.AdamW(
91-
model.parameters(),
92-
lr=args.learning_rate,
93-
betas=(args.adam_beta1, args.adam_beta2),
94-
weight_decay=args.adam_weight_decay,
95-
eps=args.adam_epsilon,
96-
)
9796

97+
## Initialize optimizer
98+
if not args.use_lion:
99+
optimizer = AdamW(
100+
model.parameters(),
101+
lr=args.learning_rate,
102+
betas=(args.adam_beta1, args.adam_beta2),
103+
weight_decay=args.adam_weight_decay,
104+
eps=args.adam_epsilon,
105+
)
106+
else:
107+
optimizer = Lion(
108+
model.parameters(),
109+
lr=args.learning_rate,
110+
betas=(args.adam_beta1, args.adam_beta2),
111+
weight_decay=args.adam_weight_decay
112+
)
98113

99114
## TRAIN MODEL ##
100115
running_loss = 0.0

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'med-seg-diff-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.0',
6+
version = '0.1.1',
77
license='MIT',
88
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
99
author = 'Phil Wang',
@@ -19,7 +19,9 @@
1919
install_requires=[
2020
'beartype',
2121
'einops',
22+
'lion-pytorch',
2223
'torch',
24+
'torchvision',
2325
'tqdm',
2426
'accelerate',
2527
'wandb'

0 commit comments

Comments
 (0)