-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_pgd_gin.py
More file actions
110 lines (89 loc) · 3.82 KB
/
run_pgd_gin.py
File metadata and controls
110 lines (89 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright (C) king.com Ltd 2025
# License: Apache 2.0
"""
Main script to run the PGD Attack on the GIN
"""
import argparse
import copy
import numpy as np
from tqdm import tqdm
import os
import random
import pandas as pd
import torch
import torch.nn.functional as F
from src.models.gin import GIN
from src.models.utils import train_function, test_function
from src.models.utils import classification_loss, normalize_tensor_adj
from datasets.loader import data_loader
from src.attacks.topological_attacks import pgd_attack
from src.utils import set_seed
import warnings
warnings.filterwarnings('ignore')
if __name__ == "__main__":
    # Script entry point: train a GIN on one fold of a graph-classification
    # dataset, report its clean test accuracy, then run a PGD topology attack
    # on every correctly classified test graph and report the attack metrics.

    # ---- Command-line options ------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--name_dataset', type=str, default='DD',
                        help='Data set')
    # NOTE(review): --model is parsed but never read below; this script
    # always builds a GIN regardless of the value passed here.
    parser.add_argument('--model', type=str, default='GCN', help='Model type')
    parser.add_argument('--hidden_dim', type=int, default=32,
                        help='Hidden dim')
    parser.add_argument('--pooling', type=str, default='rs-pool',
                        help='Pooling Type')
    parser.add_argument('--value_alpha', type=float, default=20,
                        help='Temperature Value of the RS-Pool')
    args = parser.parse_args()

    # Fixed seed so that repeated runs are comparable.
    set_seed(42)

    # ---- Data and fixed experiment settings ----------------------------
    data = data_loader(args.name_dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 32
    training_epochs = 101
    lr = 1e-03
    attack_epochs = 100
    # Fraction used to derive the per-graph perturbation count (see below).
    budget = 0.3
    fold = 1

    # ---- Build and train the GIN ---------------------------------------
    model = GIN(data.input_dim, args.hidden_dim,
                data.num_classes, device, pooling=args.pooling,
                value_alpha=args.value_alpha).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Train the model on the selected fold.
    model = train_function(model, optimizer, data,
                           fold, device, num_epochs=training_epochs,
                           batch_size=batch_size)

    # Only the last three values returned for the fold (the test split)
    # are needed here.
    _, _, _, _, _, _, Adj_test, X_test, y_test = data.get_fold_data(fold)

    # Every metric below divides by the number of test graphs, so fail early
    # with a clear message instead of a ZeroDivisionError at the very end.
    num_test = len(Adj_test)
    if num_test == 0:
        raise ValueError(
            f"Fold {fold} of dataset {args.name_dataset!r} has an empty "
            "test split; nothing to attack."
        )

    # ---- Clean evaluation ----------------------------------------------
    test_acc = test_function(model, data, Adj_test,
                             X_test, y_test, device,
                             batch_size=batch_size, verbose=True)
    # test_function may hand back a tensor or a plain number.
    clean_acc = test_acc.item() if hasattr(test_acc, "item") else test_acc
    print(f"Clean Accuracy: {clean_acc}")

    # ---- PGD attack, one test graph at a time --------------------------
    success = 0   # attacked graphs whose prediction became wrong
    correct = 0   # graphs classified correctly before the attack
    samples = zip(Adj_test, X_test, y_test)
    for ori_adj, x, y in tqdm(samples, total=num_test,
                              desc=f"Attack budget {budget}"):
        # Perturbation count: `budget` times half the adjacency sum, plus
        # one so that at least one perturbation is always allowed.
        n_perturbations = int(ori_adj.sum() // 2 * budget) + 1

        ori_adj = ori_adj.to(device)
        x = x.to(device)
        y = y.to(device)

        # Predict on a deep copy of the adjacency (presumably so that
        # predict() cannot alter the matrix handed to the attacker; verify
        # against GIN.predict).
        pred = model.predict(copy.deepcopy(ori_adj), x)

        # Only graphs that are classified correctly are worth attacking.
        if pred.detach().max(1)[1] == y:
            correct += 1
            attacker = pgd_attack(ori_adj, x, y, model, device)
            attacker.attack(attack_epochs, n_perturbations)
            adj_attacked = attacker.modified_adj
            pred_attacked = model.predict(adj_attacked, x)

            # The attack succeeds when the prediction is no longer correct.
            if pred_attacked.detach().max(1)[1] != y:
                success += 1

    # ---- Final metrics --------------------------------------------------
    # Note: the success rate is normalised by ALL test graphs, not only by
    # the ones that were attacked.
    attack_success_rate = success / num_test
    acc_rate = correct / num_test
    # Equivalent to (correct - success) / num_test: the accuracy that is
    # left once the successfully attacked graphs are counted as errors.
    attacked_acc = 1 - ((1 - acc_rate) + attack_success_rate)

    # Adjacent f-strings instead of a backslash inside the literal, so the
    # source indentation can never leak into the printed message.
    print(f"Budget: {budget} - Success Rate: {attack_success_rate}, "
          f"Accuracy Rate: {acc_rate}, Attacked Acc: {attacked_acc}")