
Commit 2862a24

Add files via upload
1 parent 9153484 commit 2862a24

1 file changed: 160 additions, 0 deletions
@@ -0,0 +1,160 @@
"""
Author: Zhuoning Yuan

"""

"""
# **Importing LibAUC**"""

from libauc.losses import AUCMLoss
from libauc.optimizers import PESG
from libauc.models import DenseNet121, DenseNet169
from libauc.datasets import Melanoma
from libauc.utils import auroc

import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
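
# Assuming a standard pip environment, the dependencies used below can be
# installed with `pip install libauc albumentations` (plus torch/torchvision);
# the exact versions matching this example are not pinned here.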

"""# **Reproducibility**"""

def set_all_seeds(SEED):
    # REPRODUCIBILITY
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
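
# Optional, illustrative extension (not part of the original script): seeding
# Python's built-in `random` module and all CUDA devices as well can tighten
# reproducibility when other libraries draw random numbers.
import random

def set_all_seeds_strict(SEED):
    # hypothetical helper sketched for completeness; the name is not from LibAUC
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False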

"""# **Data Augmentation**"""

import albumentations as A
from albumentations.pytorch.transforms import ToTensor

def augmentations(image_size=256, is_test=True):
    # https://www.kaggle.com/vishnus/a-simple-pytorch-starter-code-single-fold-93
    imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
    train_tfms = A.Compose([
        A.Cutout(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.2,
                                       contrast_limit=0.2),
            A.HueSaturationValue(
                hue_shift_limit=20,
                sat_shift_limit=50,
                val_shift_limit=50)
        ], p=0.5),
        A.OneOf([
            A.IAAAdditiveGaussianNoise(),
            A.GaussNoise(),
        ], p=0.5),
        A.OneOf([
            A.MotionBlur(p=0.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ], p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5),
        A.OneOf([
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=0.1),
            A.IAAPiecewiseAffine(p=0.3),
        ], p=0.5),
        ToTensor(normalize=imagenet_stats)
    ])

    test_tfms = A.Compose([ToTensor(normalize=imagenet_stats)])
    if is_test:
        return test_tfms
    else:
        return train_tfms
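
# Note (assumption about newer library versions): the pipeline above targets an
# older albumentations release. In more recent releases `ToTensor`, `Cutout`
# and the `IAA*` transforms were removed or renamed (e.g. `CoarseDropout`,
# `PiecewiseAffine`, `ToTensorV2`); under that assumption, a rough equivalent
# of the test-time pipeline would be:
#
#   from albumentations.pytorch import ToTensorV2
#   test_tfms = A.Compose([
#       A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#       ToTensorV2(),
#   ])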

"""# **Optimizing AUCM Loss**
* Installation of `albumentations` is required!
"""

# dataset
trainSet = Melanoma(root='./melanoma/', is_test=False, test_size=0.2, transforms=augmentations)
testSet = Melanoma(root='./melanoma/', is_test=True, test_size=0.2, transforms=augmentations)

# parameters
SEED = 123
BATCH_SIZE = 64
lr = 0.1
gamma = 500
imratio = trainSet.imratio
weight_decay = 1e-5
margin = 1.0
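
# `imratio` is the fraction of positive (melanoma) samples in the training
# split, exposed by the Melanoma dataset object; for the heavily imbalanced
# SIIM-ISIC data it is on the order of a few percent. `margin` is the margin
# in the AUCM surrogate loss, and `gamma` controls the strength of PESG's
# epoch-level regularization (interpretation per the LibAUC examples).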
91+
92+
# model
93+
set_all_seeds(SEED)
94+
model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=1)
95+
model = model.cuda()
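
# This script assumes a CUDA-capable GPU. A device-agnostic variant (an
# illustrative sketch, not part of the original example) would be:
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = model.to(device)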

trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, num_workers=2, shuffle=True)
testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, num_workers=2, shuffle=False)

# load your own pretrained model here
# PATH = 'ce_pretrained_model.pth'
# state_dict = torch.load(PATH)
# state_dict.pop('classifier.weight', None)
# state_dict.pop('classifier.bias', None)
# model.load_state_dict(state_dict, strict=False)
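# The two `pop` calls above drop the final classifier weights from the
# checkpoint so that a model pretrained with a different output head (e.g.
# cross-entropy with another number of classes) can still be loaded;
# `strict=False` then tolerates the missing keys.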

# define loss & optimizer
Loss = AUCMLoss(imratio=imratio)
optimizer = PESG(model,
                 a=Loss.a,
                 b=Loss.b,
                 alpha=Loss.alpha,
                 lr=lr,
                 gamma=gamma,
                 margin=margin,
                 weight_decay=weight_decay)
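
# AUCMLoss is a margin-based, min-max surrogate of the AUC: `a` and `b` track
# the mean prediction score of the positive and negative class, and `alpha` is
# the dual variable of the inner max problem, which is why they are handed to
# the PESG optimizer alongside the model parameters. Later LibAUC releases
# changed the AUCMLoss/PESG constructor signatures, so this call matches the
# older API used in this example; consult the installed version's docs if it fails.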

total_epochs = 16
best_val_auc = 0
for epoch in range(total_epochs):

    # reset stages
    if epoch == int(total_epochs*0.5) or epoch == int(total_epochs*0.75):
        optimizer.update_regularizer(decay_factor=10)
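        # Per the LibAUC examples, update_regularizer() decays the learning
        # rate (here by 10x) and refreshes the reference point used by PESG's
        # epoch-level regularizer; exact behaviour may differ across versions.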

    # training
    for idx, data in enumerate(trainloader):
        train_data, train_labels = data
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        y_pred = torch.sigmoid(y_pred)
        loss = Loss(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # validation
    model.eval()
    with torch.no_grad():
        test_pred = []
        test_true = []
        for jdx, data in enumerate(testloader):
            test_data, test_label = data
            test_data = test_data.cuda()
            y_pred = model(test_data)
            y_pred = torch.sigmoid(y_pred)
            test_pred.append(y_pred.cpu().detach().numpy())
            test_true.append(test_label.numpy())

        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        val_auc = auroc(test_true, test_pred)
        model.train()

    if best_val_auc < val_auc:
        best_val_auc = val_auc
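        # Illustrative addition (not in the original script): persist the best
        # checkpoint; the filename below is arbitrary.
        # torch.save(model.state_dict(), 'aucm_melanoma_best.pth')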

    print('Epoch=%s, Loss=%.4f, Val_AUC=%.4f, lr=%.4f' % (epoch, loss, val_auc, optimizer.lr))

print('Best Val_AUC is %.4f' % best_val_auc)
