Skip to content

Commit 2e4a691

Browse files
authored
Update 11_optimizing_pauc_loss_on_imbalanced_data_wrapper.py
1 parent 6912158 commit 2e4a691

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

examples/scripts/11_optimizing_pauc_loss_on_imbalanced_data_wrapper.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,6 @@ def __getitem__(self, idx):
107107

108108
# sampling parameters
109109
sampling_rate = 0.5
110-
num_pos = int(batch_size*sampling_rate)
111-
num_neg = int(batch_size*(1-sampling_rate))
112-
113110

114111
train_data, train_targets = CIFAR10(root='./data', train=True)
115112
test_data, test_targets = CIFAR10(root='./data', train=False)
@@ -137,7 +134,7 @@ def __getitem__(self, idx):
137134
model = resnet18(pretrained=False, num_classes=1, last_activation=None)
138135
model = model.cuda()
139136

140-
loss_fn = pAUCLoss(pos_len=sampler.pos_len, backend='SOPA', beta=beta, num_neg=num_neg, margin=margin)
137+
loss_fn = pAUCLoss(pos_len=sampler.pos_len, backend='SOPA', beta=beta, margin=margin)
141138
optimizer = SOPA(model.parameters(), loss_fn=loss_fn.loss_fn, mode='adam', lr=lr, eta=eta, weight_decay=weight_decay)
142139

143140

@@ -202,4 +199,4 @@ def __getitem__(self, idx):
202199
plt.title('CIFAR-10 (20% imbalanced)',fontsize=30)
203200
plt.legend(fontsize=15)
204201
plt.ylabel('OPAUC(0.3)',fontsize=25)
205-
plt.xlabel('epochs',fontsize=25)
202+
plt.xlabel('epochs',fontsize=25)

0 commit comments

Comments
 (0)