Skip to content

Commit 3e8ac20

Browse files
authored
Update evaluate.py
1 parent 90fce43 commit 3e8ac20

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

AROS/evaluate.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,24 @@
11
import numpy as np
2-
from tqdm import tqdm
2+
from tqdm.notebook import tqdm
33
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc
44
from utils import *
5+
import argparse
6+
import torch
7+
import torch.nn as nn
8+
import torchvision
9+
import torchvision.transforms as transforms
10+
from torch.utils.data import DataLoader
11+
from sklearn.mixture import GaussianMixture
12+
import numpy as np
13+
from scipy.stats import multivariate_normal
14+
from sklearn.covariance import EmpiricalCovariance
15+
from robustbench.utils import load_model
16+
import torch.nn.functional as F
17+
from torch.utils.data import TensorDataset
18+
from torch.optim.lr_scheduler import StepLR
19+
from tqdm.notebook import tqdm
20+
21+
522
def compute_fpr95(labels, scores):
623

724
fpr, tpr, thresholds = roc_curve(labels, scores)
@@ -60,9 +77,9 @@ def get_clean_AUC(model, test_loader , device, num_classes):
6077
auroc = compute_auroc(test_labels, anomaly_scores)
6178
aupr = compute_aupr(test_labels, anomaly_scores)
6279

63-
# print(f"FPR95: {fpr95}")
80+
print(f"FPR95: {fpr95}")
6481
print(f"AUROC is: {auroc}")
65-
# print(f"AUPR: {aupr}")
82+
print(f"AUPR: {aupr}")
6683

6784
return auc
6885

@@ -598,9 +615,9 @@ def get_auc_adversarial(model, test_loader, test_attack, device, num_classes):
598615
auroc = compute_auroc(test_labels, anomaly_scores)
599616
aupr = compute_aupr(test_labels, anomaly_scores)
600617

601-
# print(f"FPR95: {fpr95}")
618+
print(f"FPR95: {fpr95}")
602619
print(f"AUROC is: {auroc}")
603-
# print(f"AUPR: {aupr}")
620+
print(f"AUPR: {aupr}")
604621

605622

606623
if is_train:

0 commit comments

Comments
 (0)