Commit a2d8c1d

Update Main.py
1 parent c1dd996 commit a2d8c1d

File tree

1 file changed: +87 -85 lines changed

AROS/Main.py

Lines changed: 87 additions & 85 deletions
@@ -1,126 +1,128 @@
+
+!pip install -r requirements.txt
+import argparse
 import torch
 import torch.nn as nn
-import torchvision
-import torchvision.transforms as transforms
-from torch.utils.data import DataLoader
-from sklearn.mixture import GaussianMixture
-import numpy as np
-from scipy.stats import multivariate_normal
-from sklearn.covariance import EmpiricalCovariance
-from robustbench.utils import load_model
-import torch.nn.functional as F
-from torch.utils.data import TensorDataset
-
-
-num_vclasses=100
-
-
-num_samples_needed=1
-
-fast=True
-epoch1=1
-epoch2=1
-epoch3=1
+from evaluate import *
+from utils import *
+from tqdm.notebook import tqdm
+from data_loader import *
+from stability_loss_function import *

+def main():
+    parser = argparse.ArgumentParser(description="Hyperparameters for the script")

-model_name_='Wang2023Better_WRN-70-16'
-in_dataset='cifar100'
-threat_model_='Linf'
-
-
-
-# cifa10_models=['Ding2020MMA','Rebuffi2021Fixing_70_16_cutmix_extra'] 50000//num_classes
+    # Define the hyperparameters controlled via the CLI
+    parser.add_argument('--fast', type=bool, default=True, help='Toggle between fast and full fake data generation modes')
+    parser.add_argument('--epoch1', type=int, default=2, help='Number of epochs for stage 1')
+    parser.add_argument('--epoch2', type=int, default=1, help='Number of epochs for stage 2')
+    parser.add_argument('--epoch3', type=int, default=2, help='Number of epochs for stage 3')
+    parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')
+    parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')
+    parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')
+    parser.add_argument('--attack_eps', type=float, default=8/255, help='Perturbation bound (epsilon) for PGD attack')
+    parser.add_argument('--attack_steps', type=int, default=10, help='Number of steps for the PGD attack')
+    parser.add_argument('--attack_alpha', type=float, default=2.5 * (8/255) / 10, help='Step size (alpha) for each PGD attack iteration')

-# cifar100_models=['Wang2023Better_WRN-70-16','Rice2020Overfitting']
+    args = parser.parse_args('')

+    # Set the default model name based on the selected dataset
+    if args.in_dataset == 'cifar10':
+        default_model_name = 'Rebuffi2021Fixing_70_16_cutmix_extra'
+    elif args.in_dataset == 'cifar100':
+        default_model_name = 'Wang2023Better_WRN-70-16'

+    parser.add_argument('--model_name', type=str, default=default_model_name, choices=['Rebuffi2021Fixing_70_16_cutmix_extra', 'Wang2023Better_WRN-70-16'], help='The pre-trained model to be used for feature extraction')

-trainloader,testloader,ID_OOD_loader=get_loaders(in_dataset=in_dataset)
+    # Re-parse arguments to include the model_name selection based on the dataset
+    args = parser.parse_args('')
+    num_classes = 10 if args.in_dataset == 'cifar10' else 100

+    trainloader, testloader, test_set, ID_OOD_loader = get_loaders(in_dataset=args.in_dataset)

-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


+    robust_backbone = load_model(model_name=args.model_name, dataset=args.in_dataset, threat_model=args.threat_model).to(device)
+    last_layer_name, last_layer = list(robust_backbone.named_children())[-1]
+    setattr(robust_backbone, last_layer_name, nn.Identity())
+    fake_loader = None


-robust_backbone = load_model(model_name=model_name_, dataset=in_dataset, threat_model=threat_model_).to(device)
-last_layer_name, last_layer = list(robust_backbone.named_children())[-1]
-setattr(robust_backbone, last_layer_name, nn.Identity())
+    num_fake_samples = len(trainloader.dataset) // num_classes



-embeddings, labels = [], []

-with torch.no_grad():
-    for imgs, lbls in trainloader:
-        imgs = imgs.to(device, non_blocking=True)
-        embed = robust_backbone(imgs).cpu()  # move to CPU only once per batch
-        embeddings.append(embed)
-        labels.append(lbls)
-embeddings = torch.cat(embeddings).numpy()
-labels = torch.cat(labels).numpy()
+    embeddings, labels = [], []

+    with torch.no_grad():
+        for imgs, lbls in trainloader:
+            imgs = imgs.to(device, non_blocking=True)
+            embed = robust_backbone(imgs).cpu()  # move to CPU only once per batch
+            embeddings.append(embed)
+            labels.append(lbls)
+    embeddings = torch.cat(embeddings).numpy()
+    labels = torch.cat(labels).numpy()

-print("embedding")

+    print("embedding computed...")

-if fast==False:
-    gmm_dict = {}
-    for cls in np.unique(labels):
-        cls_embed = embeddings[labels == cls]
-        gmm = GaussianMixture(n_components=1, covariance_type='full').fit(cls_embed)
-        gmm_dict[cls] = gmm

-print("fake start")
+    if args.fast == False:
+        gmm_dict = {}
+        for cls in np.unique(labels):
+            cls_embed = embeddings[labels == cls]
+            gmm = GaussianMixture(n_components=1, covariance_type='full').fit(cls_embed)
+            gmm_dict[cls] = gmm

-fake_data = []
+        print("fake crafting...")

+        fake_data = []

-for cls, gmm in gmm_dict.items():
-    samples, likelihoods = [], []
-    while len(samples) < num_samples_needed:
-        s = gmm.sample(100)[0]
-        likelihood = gmm.score_samples(s)
-        samples.append(s[likelihood < np.quantile(likelihood, 0.001)])
-        likelihoods.append(likelihood[likelihood < np.quantile(likelihood, 0.001)])
-        if sum(len(smp) for smp in samples) >= num_samples_needed:
-            break
-    samples = np.vstack(samples)[:num_samples_needed]
-    fake_data.append(samples)

-fake_data = np.vstack(fake_data)
-fake_data = torch.tensor(fake_data).float()
-fake_data = F.normalize(fake_data, p=2, dim=1)
+        for cls, gmm in gmm_dict.items():
+            samples, likelihoods = [], []
+            while len(samples) < num_fake_samples:
+                s = gmm.sample(100)[0]
+                likelihood = gmm.score_samples(s)
+                samples.append(s[likelihood < np.quantile(likelihood, 0.001)])
+                likelihoods.append(likelihood[likelihood < np.quantile(likelihood, 0.001)])
+                if sum(len(smp) for smp in samples) >= num_fake_samples:
+                    break
+            samples = np.vstack(samples)[:num_fake_samples]
+            fake_data.append(samples)

-fake_labels = torch.full((fake_data.shape[0],), 10)
-fake_loader = DataLoader(TensorDataset(fake_data, fake_labels), batch_size=128, shuffle=True)
+        fake_data = np.vstack(fake_data)
+        fake_data = torch.tensor(fake_data).float()
+        fake_data = F.normalize(fake_data, p=2, dim=1)

-if fast==True:
+        fake_labels = torch.full((fake_data.shape[0],), num_classes)
+        fake_loader = DataLoader(TensorDataset(fake_data, fake_labels), batch_size=128, shuffle=True)

+    if args.fast == True:

-noise_std = 0.1  # standard deviation of noise
-noisy_embeddings = torch.tensor(embeddings) + noise_std * torch.randn_like(torch.tensor(embeddings))

-# Normalize Noisy Embeddings
-noisy_embeddings = F.normalize(noisy_embeddings, p=2, dim=1)[:len(trainloader.dataset)//num_classes]
+        noise_std = 0.1  # standard deviation of noise
+        noisy_embeddings = torch.tensor(embeddings) + noise_std * torch.randn_like(torch.tensor(embeddings))

-# Convert to DataLoader if needed
-fake_labels = torch.full((noisy_embeddings.shape[0],), num_classes)[:len(trainloader.dataset)//num_classes]
-fake_loader = DataLoader(TensorDataset(noisy_embeddings, fake_labels), batch_size=128, shuffle=True)
+        # Normalize Noisy Embeddings
+        noisy_embeddings = F.normalize(noisy_embeddings, p=2, dim=1)[:len(trainloader.dataset)//num_classes]

+        # Convert to DataLoader if needed
+        fake_labels = torch.full((noisy_embeddings.shape[0],), num_classes)[:len(trainloader.dataset)//num_classes]
+        fake_loader = DataLoader(TensorDataset(noisy_embeddings, fake_labels), batch_size=128, shuffle=True)


-final_model=stability_loss_function_(trainloader,testloader,robust_backbone,num_classes,fake_loader,last_layer)
+    final_model = stability_loss_function_(trainloader, testloader, robust_backbone, num_classes, fake_loader, last_layer, args)

-
-
-attack_eps = 8/255
-attack_steps = 10
-attack_alpha = 2.5 * attack_eps / attack_steps
-test_attack = PGD_AUC(final_model, eps=attack_eps, steps=attack_steps, alpha=attack_alpha, num_classes=num_classes)
+
+    test_attack = PGD_AUC(final_model, eps=args.attack_eps, steps=args.attack_steps, alpha=args.attack_alpha, num_classes=num_classes)
+    get_clean_AUC(final_model, ID_OOD_loader, device, num_classes)
+    adv_auc = get_auc_adversarial(model=final_model, test_loader=ID_OOD_loader, test_attack=test_attack, device=device, num_classes=num_classes)



-get_clean_AUC(final_model, ID_OOD_loader , device, num_classes)
+if __name__ == "__main__":
+    main()

-adv_auc = get_auc_adversarial(model=final_model, test_loader=ID_OOD_loader, test_attack=test_attack, device=device, num_classes=num_classes)
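Since main() calls parser.parse_args(''), argparse iterates an empty string, sees zero tokens, and every run falls back to the defaults defined above; real command-line flags are only read when parse_args is handed sys.argv (its default) or an explicit list. A minimal sketch of that behavior, reusing two flags from this commit with hypothetical override values:

import argparse

parser = argparse.ArgumentParser(description="Hyperparameters for the script")
parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
parser.add_argument('--attack_steps', type=int, default=10)

defaults = parser.parse_args('')  # '' yields no tokens, so defaults apply
override = parser.parse_args(['--in_dataset', 'cifar100', '--attack_steps', '20'])
print(defaults.in_dataset, defaults.attack_steps)  # cifar10 10
print(override.in_dataset, override.attack_steps)  # cifar100 20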

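The slow path (fast=False) crafts fake out-of-distribution embeddings by fitting a one-component Gaussian per class and keeping only draws whose log-likelihood falls below the 0.001 quantile, i.e. points far from the class mode. A self-contained sketch of that rejection loop, with toy dimensions standing in for the per-class backbone embeddings:

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
cls_embed = rng.normal(size=(500, 16))  # toy stand-in for one class's embeddings
num_fake_samples = 5

gmm = GaussianMixture(n_components=1, covariance_type='full').fit(cls_embed)

samples = []
while sum(len(s) for s in samples) < num_fake_samples:
    draws, _ = gmm.sample(100)     # candidate points from the fitted Gaussian
    ll = gmm.score_samples(draws)  # log-likelihood of each candidate
    samples.append(draws[ll < np.quantile(ll, 0.001)])  # keep the unlikeliest tail
fake = np.vstack(samples)[:num_fake_samples]
print(fake.shape)  # (5, 16)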
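The fast path skips density fitting entirely: it perturbs the real embeddings with Gaussian noise, projects them back onto the unit L2 sphere, and labels them with index num_classes so they act as an extra (num_classes + 1)-th "OOD" class during training. A toy sketch of the same recipe, with made-up shapes:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

num_classes, noise_std = 10, 0.1
embeddings = torch.randn(256, 16)  # toy stand-in for backbone embeddings
noisy = embeddings + noise_std * torch.randn_like(embeddings)
noisy = F.normalize(noisy, p=2, dim=1)  # unit L2 norm per row
fake_labels = torch.full((noisy.shape[0],), num_classes)  # the extra OOD class index
fake_loader = DataLoader(TensorDataset(noisy, fake_labels), batch_size=128, shuffle=True)
print(next(iter(fake_loader))[0].shape)  # torch.Size([128, 16])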