Skip to content

Commit 28c2252

Browse files
hsirmMMathisLab
authored and committed
hello, world
1 parent bb6ab36 commit 28c2252

File tree

7 files changed

+1891
-0
lines changed

7 files changed

+1891
-0
lines changed

AROS/Main.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import torch
2+
import torch.nn as nn
3+
import torchvision
4+
import torchvision.transforms as transforms
5+
from torch.utils.data import DataLoader
6+
from sklearn.mixture import GaussianMixture
7+
import numpy as np
8+
from scipy.stats import multivariate_normal
9+
from sklearn.covariance import EmpiricalCovariance
10+
from robustbench.utils import load_model
11+
import torch.nn.functional as F
12+
from torch.utils.data import TensorDataset
13+
14+
15+
# ----- Experiment configuration -----

# Number of in-distribution classes (CIFAR-100).
# FIX(review): the original defined `num_vclasses` (typo) but every downstream
# use reads `num_classes`, which was never defined -> NameError at runtime.
num_classes = 100
num_vclasses = num_classes  # backward-compat alias for the original (typo'd) name

# How many fake (outlier) samples to keep per class in the slow GMM path.
num_samples_needed = 1

# fast=True: cheap noisy-embedding outliers; fast=False: per-class GMM tail sampling.
fast = True

# Epoch budgets for the three training stages.
epoch1 = 1
epoch2 = 1
epoch3 = 1

# Robust backbone checkpoint (RobustBench model zoo) and threat model.
model_name_ = 'Wang2023Better_WRN-70-16'
in_dataset = 'cifar100'
threat_model_ = 'Linf'

# cifa10_models=['Ding2020MMA','Rebuffi2021Fixing_70_16_cutmix_extra'] 50000//num_classes
# cifar100_models=['Wang2023Better_WRN-70-16','Rice2020Overfitting']
37+
38+
# FIX(review): get_loaders is defined in AROS/data_loader.py but was never
# imported here -- TODO confirm the module path matches the deployed layout.
from data_loader import get_loaders

# In-distribution train/test loaders plus the combined ID-vs-OOD test loader.
trainloader, testloader, ID_OOD_loader = get_loaders(in_dataset=in_dataset)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the adversarially robust backbone and replace its classification head
# with Identity so forward() yields penultimate-layer embeddings.
robust_backbone = load_model(model_name=model_name_, dataset=in_dataset,
                             threat_model=threat_model_).to(device)
last_layer_name, last_layer = list(robust_backbone.named_children())[-1]
setattr(robust_backbone, last_layer_name, nn.Identity())

# FIX(review): feature extraction should run in eval mode so BatchNorm uses
# running statistics instead of per-batch statistics.
robust_backbone.eval()

# Extract embeddings for the whole training set.
embeddings, labels = [], []
with torch.no_grad():
    for imgs, lbls in trainloader:
        imgs = imgs.to(device, non_blocking=True)
        embed = robust_backbone(imgs).cpu()  # move to CPU only once per batch
        embeddings.append(embed)
        labels.append(lbls)
embeddings = torch.cat(embeddings).numpy()
labels = torch.cat(labels).numpy()

print("embedding")
66+
67+
if fast == False:
    # Slow outlier-synthesis path: fit one full-covariance Gaussian per class
    # in embedding space, then keep only draws from the extreme low-density
    # tail (bottom 0.1% likelihood quantile) as "fake" OOD samples.
    gmm_dict = {}
    for cls in np.unique(labels):
        cls_embed = embeddings[labels == cls]
        gmm = GaussianMixture(n_components=1, covariance_type='full').fit(cls_embed)
        gmm_dict[cls] = gmm

    print("fake start")

    fake_data = []
    for cls, gmm in gmm_dict.items():
        samples, likelihoods = [], []
        while len(samples) < num_samples_needed:
            s = gmm.sample(100)[0]
            likelihood = gmm.score_samples(s)
            # Hoisted: the original recomputed this quantile twice per draw.
            cutoff = np.quantile(likelihood, 0.001)
            samples.append(s[likelihood < cutoff])
            likelihoods.append(likelihood[likelihood < cutoff])
            if sum(len(smp) for smp in samples) >= num_samples_needed:
                break
        samples = np.vstack(samples)[:num_samples_needed]
        fake_data.append(samples)

    fake_data = np.vstack(fake_data)
    fake_data = torch.tensor(fake_data).float()
    fake_data = F.normalize(fake_data, p=2, dim=1)

    # FIX(review): the extra "OOD" class index was hard-coded to 10 (CIFAR-10
    # legacy); use num_classes so CIFAR-100 fakes are labeled 100, consistent
    # with the fast path.
    fake_labels = torch.full((fake_data.shape[0],), num_classes)
    fake_loader = DataLoader(TensorDataset(fake_data, fake_labels), batch_size=128, shuffle=True)
97+
98+
if fast == True:
    # Fast outlier-synthesis path: perturb the clean embeddings with Gaussian
    # noise and keep roughly one class worth of them as fake OOD samples.
    noise_std = 0.1  # standard deviation of the additive noise

    # FIX(review): the original converted `embeddings` to a tensor twice.
    emb = torch.as_tensor(embeddings)
    noisy_embeddings = emb + noise_std * torch.randn_like(emb)

    # L2-normalize and subsample to ~(dataset size / num_classes) points.
    keep = len(trainloader.dataset) // num_classes
    noisy_embeddings = F.normalize(noisy_embeddings, p=2, dim=1)[:keep]

    # All fake points share the extra "OOD" class index (= num_classes).
    # (The original re-sliced fake_labels by `keep`, a no-op because the
    # tensor is already created at exactly that length.)
    fake_labels = torch.full((noisy_embeddings.shape[0],), num_classes)
    fake_loader = DataLoader(TensorDataset(noisy_embeddings, fake_labels), batch_size=128, shuffle=True)
110+
111+
112+
113+
# Fine-tune the detector with the stability objective, then evaluate clean
# and adversarial OOD-detection AUC on the combined ID/OOD test loader.
final_model = stability_loss_function_(trainloader, testloader, robust_backbone,
                                       num_classes, fake_loader, last_layer)

# PGD attack: Linf budget eps = 8/255, step size 2.5 * eps / steps.
attack_eps = 8 / 255
attack_steps = 10
attack_alpha = 2.5 * attack_eps / attack_steps
test_attack = PGD_AUC(final_model, eps=attack_eps, steps=attack_steps,
                      alpha=attack_alpha, num_classes=num_classes)

# Clean AUC first, then AUC under the PGD attack.
get_clean_AUC(final_model, ID_OOD_loader, device, num_classes)
adv_auc = get_auc_adversarial(model=final_model, test_loader=ID_OOD_loader,
                              test_attack=test_attack, device=device,
                              num_classes=num_classes)

AROS/data_loader.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
import torchvision
3+
from torch.utils.data import DataLoader, Dataset,Subset, SubsetRandomSampler, TensorDataset, ConcatDataset
4+
from torchvision import datasets, transforms, models
5+
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, FashionMNIST
6+
7+
8+
9+
class LabelChangedDataset(Dataset):
    """Serve every item of a wrapped dataset with one fixed replacement label.

    Useful for marking an entire auxiliary dataset as a single class
    (e.g. a shared out-of-distribution label).
    """

    def __init__(self, original_dataset, new_label):
        self.original_dataset = original_dataset
        self.new_label = new_label

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        # Discard the wrapped dataset's label; keep only the image.
        return self.original_dataset[idx][0], self.new_label
20+
21+
22+
23+
def get_subsampled_subset(dataset, subset_ratio=0.1):
    """Return a random Subset holding `subset_ratio` of `dataset`.

    No seed is fixed, so repeated calls select different indices.
    """
    n_keep = int(len(dataset) * subset_ratio)
    n_drop = len(dataset) - n_keep
    kept, _ = torch.utils.data.random_split(dataset, [n_keep, n_drop])
    # random_split already returns a Subset; re-wrap its indices so the
    # result is a Subset bound directly to the full dataset, as before.
    return Subset(dataset, kept.indices)
31+
32+
33+
34+
# Plain ToTensor pipeline -- no augmentation or normalization.
transform_tensor = transforms.Compose([transforms.ToTensor()])

# --- CIFAR-10 ---
trainset_CIFAR10 = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_tensor)
testset_CIFAR10 = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_tensor)
trainloader_CIFAR10 = DataLoader(trainset_CIFAR10, batch_size=64, shuffle=True, num_workers=2)
testloader_CIFAR10 = DataLoader(testset_CIFAR10, batch_size=16, shuffle=False, num_workers=2)

# --- CIFAR-100 ---
trainset_CIFAR100 = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform_tensor)
testset_CIFAR100 = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=transform_tensor)
trainloader_CIFAR100 = DataLoader(trainset_CIFAR100, batch_size=64, shuffle=True, num_workers=2)
testloader_CIFAR100 = DataLoader(testset_CIFAR100, batch_size=16, shuffle=False, num_workers=2)

# Cross-dataset OOD test sets: the "other" dataset is relabeled to one
# out-of-range class index (100 when CIFAR-10 is ID, 10 when CIFAR-100 is ID).
testset_CIFAR10_relabled = LabelChangedDataset(testset_CIFAR10, new_label=100)
testset_CIFAR100_relabled = LabelChangedDataset(testset_CIFAR100, new_label=10)

# Each loader concatenates the ID test set with the relabeled OOD test set.
testloader_CIFAR10_vs_CIFAR100 = DataLoader(
    ConcatDataset([testset_CIFAR10, testset_CIFAR100_relabled]), shuffle=False, batch_size=16)
testloader_CIFAR100_vs_CIFAR10 = DataLoader(
    ConcatDataset([testset_CIFAR100, testset_CIFAR10_relabled]), shuffle=False, batch_size=16)
69+
70+
def get_loaders(in_dataset='cifar10'):
    """Return (trainloader, testloader, ID_vs_OOD_testloader) for a dataset.

    Args:
        in_dataset: 'cifar10' or 'cifar100' (case-insensitive).

    Raises:
        ValueError: for any other dataset name.
    """
    # FIX(review): normalize case -- the original default 'CIFAR10' failed its
    # own lowercase check and raised ValueError.
    name = in_dataset.lower()
    if name == 'cifar10':
        return trainloader_CIFAR10, testloader_CIFAR10, testloader_CIFAR10_vs_CIFAR100
    # FIX(review): Main.py calls get_loaders(in_dataset='cifar100'); the
    # CIFAR-100 loaders exist above but the original raised ValueError here.
    if name == 'cifar100':
        return trainloader_CIFAR100, testloader_CIFAR100, testloader_CIFAR100_vs_CIFAR10
    raise ValueError(f"Dataset '{in_dataset}' is not supported.")

0 commit comments

Comments
 (0)