1+ """
2+ Author: Zhuoning Yuan
3+ 4+ """
5+
6+ """
7+ # **Importing LibAUC**"""
8+
9+ from libauc .losses import AUCMLoss
10+ from libauc .optimizers import PESG
11+ from libauc .models import DenseNet121 , DenseNet169
12+ from libauc .datasets import Melanoma
13+ from libauc .utils import auroc
14+
15+ import torch
16+ from PIL import Image
17+ import numpy as np
18+ import torchvision .transforms as transforms
19+ from torch .utils .data import Dataset
20+
21+ """# **Reproducibility**"""
22+
23+ def set_all_seeds (SEED ):
24+ # REPRODUCIBILITY
25+ torch .manual_seed (SEED )
26+ np .random .seed (SEED )
27+ torch .backends .cudnn .deterministic = True
28+ torch .backends .cudnn .benchmark = False
29+
30+ """# **Data Augmentation**"""
31+
32+ import albumentations as A
33+ from albumentations .pytorch .transforms import ToTensor
34+
35+ def augmentations (image_size = 256 , is_test = True ):
36+ # https://www.kaggle.com/vishnus/a-simple-pytorch-starter-code-single-fold-93
37+ imagenet_stats = {'mean' :[0.485 , 0.456 , 0.406 ], 'std' :[0.229 , 0.224 , 0.225 ]}
38+ train_tfms = A .Compose ([
39+ A .Cutout (p = 0.5 ),
40+ A .RandomRotate90 (p = 0.5 ),
41+ A .Flip (p = 0.5 ),
42+ A .OneOf ([
43+ A .RandomBrightnessContrast (brightness_limit = 0.2 ,
44+ contrast_limit = 0.2 ,
45+ ),
46+ A .HueSaturationValue (
47+ hue_shift_limit = 20 ,
48+ sat_shift_limit = 50 ,
49+ val_shift_limit = 50 )
50+ ], p = 0.5 ),
51+ A .OneOf ([
52+ A .IAAAdditiveGaussianNoise (),
53+ A .GaussNoise (),
54+ ], p = 0.5 ),
55+ A .OneOf ([
56+ A .MotionBlur (p = 0.2 ),
57+ A .MedianBlur (blur_limit = 3 , p = 0.1 ),
58+ A .Blur (blur_limit = 3 , p = 0.1 ),
59+ ], p = 0.5 ),
60+ A .ShiftScaleRotate (shift_limit = 0.0625 , scale_limit = 0.2 , rotate_limit = 45 , p = 0.5 ),
61+ A .OneOf ([
62+ A .OpticalDistortion (p = 0.3 ),
63+ A .GridDistortion (p = 0.1 ),
64+ A .IAAPiecewiseAffine (p = 0.3 ),
65+ ], p = 0.5 ),
66+ ToTensor (normalize = imagenet_stats )
67+ ])
68+
69+ test_tfms = A .Compose ([ToTensor (normalize = imagenet_stats )])
70+ if is_test :
71+ return test_tfms
72+ else :
73+ return train_tfms
74+
75+ """# **Optimizing AUCM Loss**
76+ * Installation of `albumentations` is required!
77+ """
78+
79+ # dataset
80+ trainSet = Melanoma (root = './melanoma/' , is_test = False , test_size = 0.2 , transforms = augmentations )
81+ testSet = Melanoma (root = './melanoma/' , is_test = True , test_size = 0.2 , transforms = augmentations )
82+
83+ # paramaters
84+ SEED = 123
85+ BATCH_SIZE = 64
86+ lr = 0.1
87+ gamma = 500
88+ imratio = trainSet .imratio
89+ weight_decay = 1e-5
90+ margin = 1.0
91+
92+ # model
93+ set_all_seeds (SEED )
94+ model = DenseNet121 (pretrained = True , last_activation = None , activations = 'relu' , num_classes = 1 )
95+ model = model .cuda ()
96+
97+ trainloader = torch .utils .data .DataLoader (trainSet , batch_size = BATCH_SIZE , num_workers = 2 , shuffle = True )
98+ testloader = torch .utils .data .DataLoader (testSet , batch_size = BATCH_SIZE , num_workers = 2 , shuffle = False )
99+
100+ # load your own pretrained model here
101+ # PATH = 'ce_pretrained_model.pth'
102+ # state_dict = torch.load(PATH)
103+ # state_dict.pop('classifier.weight', None)
104+ # state_dict.pop('classifier.bias', None)
105+ # model.load_state_dict(state_dict, strict=False)
106+
107+ # define loss & optimizer
108+ Loss = AUCMLoss (imratio = imratio )
109+ optimizer = PESG (model ,
110+ a = Loss .a ,
111+ b = Loss .b ,
112+ alpha = Loss .alpha ,
113+ lr = lr ,
114+ gamma = gamma ,
115+ margin = margin ,
116+ weight_decay = weight_decay )
117+
118+ total_epochs = 16
119+ best_val_auc = 0
120+ for epoch in range (total_epochs ):
121+
122+ # reset stages
123+ if epoch == int (total_epochs * 0.5 ) or epoch == int (total_epochs * 0.75 ):
124+ optimizer .update_regularizer (decay_factor = 10 )
125+
126+ # training
127+ for idx , data in enumerate (trainloader ):
128+ train_data , train_labels = data
129+ train_data , train_labels = train_data .cuda (), train_labels .cuda ()
130+ y_pred = model (train_data )
131+ y_pred = torch .sigmoid (y_pred )
132+ loss = Loss (y_pred , train_labels )
133+ optimizer .zero_grad ()
134+ loss .backward ()
135+ optimizer .step ()
136+
137+ # validation
138+ model .eval ()
139+ with torch .no_grad ():
140+ test_pred = []
141+ test_true = []
142+ for jdx , data in enumerate (testloader ):
143+ test_data , test_label = data
144+ test_data = test_data .cuda ()
145+ y_pred = model (test_data )
146+ y_pred = torch .sigmoid (y_pred )
147+ test_pred .append (y_pred .cpu ().detach ().numpy ())
148+ test_true .append (test_label .numpy ())
149+
150+ test_true = np .concatenate (test_true )
151+ test_pred = np .concatenate (test_pred )
152+ val_auc = auroc (test_true , test_pred )
153+ model .train ()
154+
155+ if best_val_auc < val_auc :
156+ best_val_auc = val_auc
157+
158+ print ('Epoch=%s, Loss=%.4f, Val_AUC=%.4f, lr=%.4f' % (epoch , loss , val_auc , optimizer .lr ))
159+
160+ print ('Best Val_AUC is %.4f' % best_val_auc )
0 commit comments