Skip to content

Commit ae54006

Browse files
authored
Update stability_loss_function.py
1 parent 4eb73a4 commit ae54006

File tree

1 file changed

+64
-165
lines changed

1 file changed

+64
-165
lines changed

AROS/stability_loss_function.py

Lines changed: 64 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -6,155 +6,38 @@
66
from utils import *
77
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset, SubsetRandomSampler, ConcatDataset
88
import numpy as np
9-
from tqdm import tqdm
9+
from tqdm.notebook import tqdm
10+
from torch.optim.lr_scheduler import StepLR
1011

1112

12-
weight_diag = 10
13-
weight_offdiag = 0
14-
weight_f = 0.1
15-
16-
weight_norm = 0
17-
weight_lossc = 0
18-
19-
exponent = 1.0
20-
exponent_off = 0.1
21-
exponent_f = 50
22-
time_df = 1
23-
trans = 1.0
24-
transoffdig = 1.0
25-
numm = 16
26-
27-
batches_per_epoch = 128
28-
29-
ODE_FC_odebatch = 64
30-
epoch1=1
31-
epoch2=1
32-
epoch3=1
33-
13+
3414

35-
robust_feature_savefolder = './CIFAR10_resnet_Nov_1'
36-
train_savepath='./CIFAR10_train_resnetNov1.npz'
37-
test_savepath='./CIFAR10_test_resnetNov1.npz'
38-
ODE_FC_save_folder = './CIFAR10_resnet_Nov_1'
15+
robust_feature_savefolder = './CIFAR100_resnet_Nov_1'
16+
train_savepath='./CIFAR100_train_resnetNov1.npz'
17+
test_savepath='./CIFAR100_test_resnetNov1.npz'
18+
ODE_FC_save_folder = './CIFAR100_resnet_Nov_1'
3919

4020

4121

4222
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
4323

4424

4525

46-
class ODEBlocktemp(nn.Module):
47-
def __init__(self, odefunc):
48-
super(ODEBlocktemp, self).__init__()
49-
self.odefunc = odefunc
50-
self.integration_time = torch.tensor([0, 5]).float()
51-
def forward(self, x):
52-
out = self.odefunc(0, x)
53-
return out
54-
@property
55-
def nfe(self):
56-
return self.odefunc.nfe
57-
@nfe.setter
58-
def nfe(self, value):
59-
self.odefunc.nfe = value
60-
61-
62-
def accuracy(model, dataset_loader):
63-
total_correct = 0
64-
for x, y in dataset_loader:
65-
x = x.to(device)
66-
y = one_hot(np.array(y.numpy()), 10)
67-
68-
69-
target_class = np.argmax(y, axis=1)
70-
predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
71-
total_correct += np.sum(predicted_class == target_class)
72-
return total_correct / len(dataset_loader.dataset)
73-
74-
75-
76-
def save_training_feature(model, dataset_loader, fake_embeddings_loader=None ):
77-
x_save = []
78-
y_save = []
79-
modulelist = list(model)
80-
81-
for x, y in dataset_loader:
82-
x = x.to(device)
83-
y_ = y.numpy() # No need to use np.array here
84-
85-
# Forward pass through the model up to the desired layer
86-
for l in modulelist[0:2]:
87-
x = l(x)
88-
xo = x
89-
90-
x_ = xo.cpu().detach().numpy()
91-
x_save.append(x_)
92-
y_save.append(y_)
93-
94-
# Processing fake embeddings if provided
95-
if fake_embeddings_loader is not None:
96-
for x, y in fake_embeddings_loader:
97-
x = x.to(device)
98-
y_ = y.numpy() # No need to use np.array here
99-
100-
# Forward pass through the model up to the desired layer
101-
for l in modulelist[1:2]:
102-
x = l(x)
103-
xo = x
104-
105-
x_ = xo.cpu().detach().numpy()
106-
x_save.append(x_)
107-
y_save.append(y_)
108-
109-
# Concatenate all collected data before saving
110-
x_save = np.concatenate(x_save)
111-
y_save = np.concatenate(y_save)
112-
113-
114-
# Save the concatenated arrays to a file
115-
np.savez(train_savepath, x_save=x_save, y_save=y_save)
116-
117-
118-
119-
def save_testing_feature(model, dataset_loader):
120-
x_save = []
121-
y_save = []
122-
modulelist = list(model)
123-
layernum = 0
124-
for x, y in dataset_loader:
125-
x = x.to(device)
126-
y_ = np.array(y.numpy())
127-
128-
for l in modulelist[0:2]:
129-
x = l(x)
130-
xo = x
131-
x_ = xo.cpu().detach().numpy()
132-
x_save.append(x_)
133-
y_save.append(y_)
134-
135-
x_save = np.concatenate(x_save)
136-
y_save = np.concatenate(y_save)
137-
138-
np.savez(test_savepath, x_save=x_save, y_save=y_save)
139-
140-
141-
142-
143-
144-
145-
146-
14726

14827

14928

29+
ODE_FC_odebatch=100
15030
def stability_loss_function_(trainloader,testloader,robust_backbone,class_numbers,fake_loader,last_layer,args):
15131

15232

15333
robust_backbone = load_model(model_name=args.model_name, dataset=args.in_dataset, threat_model=args.threat_model).to(device)
34+
35+
15436
last_layer_name, last_layer = list(robust_backbone.named_children())[-1]
15537
setattr(robust_backbone, last_layer_name, nn.Identity())
15638

15739

40+
15841

15942
robust_backbone_fc_features = MLP_OUT_ORTH1024(last_layer.in_features)
16043

@@ -175,11 +58,13 @@ def stability_loss_function_(trainloader,testloader,robust_backbone,class_number
17558
net_save_robustfeature = net_save_robustfeature.to(device)
17659
data_gen = inf_generator(trainloader)
17760
batches_per_epoch = len(trainloader)
178-
best_acc = 0
179-
criterion = nn.CrossEntropyLoss()
18061
optimizer1 = torch.optim.Adam(net_save_robustfeature.parameters(), lr=5e-3, eps=1e-2, amsgrad=True)
62+
scheduler = StepLR(optimizer1, step_size=1, gamma=0.5) # Adjust step_size and gamma as needed
18163

64+
18265
def train_save_robustfeature(epoch):
66+
best_acc = 0
67+
criterion = nn.CrossEntropyLoss()
18368
print('\nEpoch: %d' % epoch)
18469
net_save_robustfeature.train()
18570
train_loss = 0
@@ -198,7 +83,10 @@ def train_save_robustfeature(epoch):
19883
total += targets.size(0)
19984
correct += predicted.eq(targets).sum().item()
20085
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
86+
scheduler.step()
20187

88+
89+
20290
def test_save_robustfeature(epoch):
20391
best_acc=0
20492
net_save_robustfeature.eval()
@@ -275,53 +163,55 @@ def f_regularizer(odefunc, z):
275163
test_loader_ODE = DataLoader(DenseDatasetTest(),batch_size=ODE_FC_odebatch,shuffle=True, num_workers=2)
276164
data_gen = inf_generator(train_loader_ODE)
277165
batches_per_epoch = len(train_loader_ODE)
278-
166+
279167

280168

281169

282170
optimizer2 = torch.optim.Adam(ODE_FCmodel.parameters(), lr=1e-2, eps=1e-3, amsgrad=True)
283171

284172

285-
173+
scheduler = StepLR(optimizer2, step_size=1, gamma=0.5) # Adjust step_size and gamma as needed
286174
for epoch in range(args.epoch2):
287-
for itr in tqdm(range(args.epoch2 * batches_per_epoch), desc="Training ODE block with loss function"):
288-
optimizer2.zero_grad()
289-
x, y = data_gen.__next__()
290-
x = x.to(device)
291-
292-
modulelist = list(ODE_FCmodel)
293-
y0 = x
294-
x = modulelist[0](x)
295-
y1 = x
296-
297-
y00 = y0
298-
regu1, regu2 = df_dz_regularizer(odefunc, y00)
299-
regu1 = regu1.mean()
300-
regu2 = regu2.mean()
301-
302-
regu3 = f_regularizer(odefunc, y00)
303-
regu3 = regu3.mean()
304-
305-
loss = weight_f*regu3 + weight_diag*regu1 + weight_offdiag*regu2
306-
307-
loss.backward()
308-
optimizer2.step()
309-
torch.cuda.empty_cache()
310-
311-
tqdm.write(f"Loss: {loss.item()}")
312-
175+
with tqdm(total=args.epoch2 * batches_per_epoch, desc="Training ODE block with loss function") as pbar:
176+
for itr in range(args.epoch2 * batches_per_epoch):
177+
optimizer2.zero_grad()
178+
x, y = data_gen.__next__()
179+
x = x.to(device)
180+
181+
modulelist = list(ODE_FCmodel)
182+
y0 = x
183+
x = modulelist[0](x)
184+
y1 = x
185+
186+
y00 = y0
187+
regu1, regu2 = df_dz_regularizer(odefunc, y00)
188+
regu1 = regu1.mean()
189+
regu2 = regu2.mean()
190+
191+
regu3 = f_regularizer(odefunc, y00)
192+
regu3 = regu3.mean()
193+
194+
loss = weight_f*regu3 + weight_diag*regu1 + weight_offdiag*regu2
195+
196+
loss.backward()
197+
optimizer2.step()
198+
torch.cuda.empty_cache()
199+
200+
# Set postfix to update progress bar with current loss
201+
pbar.set_postfix({"Loss": loss.item()})
202+
pbar.update(1)
203+
print("Loss", loss.item())
204+
scheduler.step() # Update the learning rate
313205

314206
current_lr = optimizer2.param_groups[0]['lr']
315207
tqdm.write(f"Epoch {epoch+1}, Learning Rate: {current_lr}")
316208

317209

318210

319211

212+
320213

321-
def one_hot(x, K):
322-
return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)
323-
324-
214+
325215

326216
feature_layers = ODEBlock(odefunc)
327217
fc_layers = MLP_OUT_LINEAR(class_numbers)
@@ -335,10 +225,11 @@ def one_hot(x, K):
335225
param.requires_grad = False
336226

337227
new_model_full = nn.Sequential(robust_backbone, robust_backbone_fc_features, ODE_FCmodel).to(device)
338-
optimizer3 = torch.optim.Adam([{'params': odefunc.parameters(), 'lr': 1e-5, 'eps':1e-6,},{'params': fc_layers.parameters(), 'lr': 1e-2, 'eps':1e-4,}], amsgrad=True)
339-
criterion = nn.CrossEntropyLoss()
228+
optimizer3 = torch.optim.Adam([{'params': odefunc.parameters(), 'lr': 1e-5, 'eps':1e-6,},{'params': fc_layers.parameters(), 'lr': 5e-3, 'eps':1e-4,}], amsgrad=True)
229+
340230

341231
def train(net, epoch):
232+
criterion = nn.CrossEntropyLoss()
342233
print('\nEpoch: %d' % epoch)
343234
net.train()
344235
train_loss = 0
@@ -357,7 +248,15 @@ def train(net, epoch):
357248
total += targets.size(0)
358249
correct += predicted.eq(targets).sum().item()
359250
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
360-
251+
361252
for epoch in range(0, args.epoch3):
362253
train(new_model_full, epoch)
363254
return new_model_full
255+
256+
257+
258+
259+
260+
261+
262+

0 commit comments

Comments
 (0)