@@ -107,9 +107,6 @@ def __getitem__(self, idx):
107107
108108# sampling parameters
109109sampling_rate = 0.5
110- num_pos = int (batch_size * sampling_rate )
111- num_neg = int (batch_size * (1 - sampling_rate ))
112-
113110
114111train_data , train_targets = CIFAR10 (root = './data' , train = True )
115112test_data , test_targets = CIFAR10 (root = './data' , train = False )
@@ -137,7 +134,7 @@ def __getitem__(self, idx):
137134model = resnet18 (pretrained = False , num_classes = 1 , last_activation = None )
138135model = model .cuda ()
139136
140- loss_fn = pAUCLoss (pos_len = sampler .pos_len , backend = 'SOPA' , beta = beta , num_neg = num_neg , margin = margin )
137+ loss_fn = pAUCLoss (pos_len = sampler .pos_len , backend = 'SOPA' , beta = beta , margin = margin )
141138optimizer = SOPA (model .parameters (), loss_fn = loss_fn .loss_fn , mode = 'adam' , lr = lr , eta = eta , weight_decay = weight_decay )
142139
143140
@@ -202,4 +199,4 @@ def __getitem__(self, idx):
202199plt .title ('CIFAR-10 (20% imbalanced)' ,fontsize = 30 )
203200plt .legend (fontsize = 15 )
204201plt .ylabel ('OPAUC(0.3)' ,fontsize = 25 )
205- plt .xlabel ('epochs' ,fontsize = 25 )
202+ plt .xlabel ('epochs' ,fontsize = 25 )
0 commit comments