
Commit c9d2183

n1108xy-Ji and Xingyuan Ji authored
[Model] RoHe (#220)

* update README.md
* RoheHAN
* update readme
* fixes
* revise readme
* update
* update
* update
* dataset
* update
* apply dataset
* add rst docs, update gammagl.utils, update dataset
* update
* update readme
* remove redundant files
* update paths

---------

Co-authored-by: Xingyuan Ji <[email protected]>

1 parent f21a720 commit c9d2183

File tree

11 files changed: +889 -4 lines changed


examples/rohehan/readme.md

Lines changed: 49 additions & 0 deletions
# Robust Heterogeneous Graph Neural Network (RoHeHAN)

This is an implementation of `RoHeHAN`, a robust heterogeneous graph neural network designed to defend against adversarial attacks on heterogeneous graphs.

- Paper link: [https://cdn.aaai.org/ojs/20357/20357-13-24370-1-2-20220628.pdf](https://cdn.aaai.org/ojs/20357/20357-13-24370-1-2-20220628.pdf)
- Original paper title: *Robust Heterogeneous Graph Neural Networks against Adversarial Attacks*
- Implemented using the `tensorlayerx` and `gammagl` libraries.

## Usage

To reproduce the RoHeHAN results on the ACM dataset, run one of the following commands (one per backend):

```bash
TL_BACKEND="torch" python rohehan_trainer.py --num_epochs 100 --gpu 0
TL_BACKEND="tensorflow" python rohehan_trainer.py --num_epochs 100 --gpu 0
```
## Performance

Reference performance numbers for the ACM dataset:

| Backend | Clean (no attack) | Attack (1 perturbation) | Attack (3 perturbations) | Attack (5 perturbations) |
| ---------- | ----------------- | ----------------------- | ------------------------ | ------------------------ |
| torch | 0.955 | 0.950 | 0.940 | 0.905 |
| tensorflow | 0.965 | 0.935 | 0.910 | 0.905 |

ACM dataset link: [https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat](https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat)
### Example Commands

You can adjust training settings, such as the number of epochs, learning rate, and dropout rate, for example:

```bash
TL_BACKEND="torch" python rohehan_trainer.py --num_epochs 200 --lr 0.005 --dropout 0.6 --gpu 0 --seed 0
```
## Notes

- The `settings` passed to each RoHeGAT layer configure the attention purifier, which improves robustness against adversarial attacks by pruning unreliable neighbors.

This implementation builds on the idea of using metapath-based transition probabilities and attention purification to improve the robustness of heterogeneous graph neural networks (HGNNs); a rough sketch of the mechanism follows.
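As a minimal illustration only (this is not the GammaGL implementation; the helper names `metapath_transition` and `purify_attention` and the dense-matrix setup are hypothetical), the sketch below shows the two ingredients: transition probabilities obtained from row-normalized adjacency products, and top-T masking of attention scores using those probabilities as a reliability prior:

```python
# Illustrative sketch only -- not the GammaGL code.
import numpy as np
import scipy.sparse as sp


def metapath_transition(pa: sp.csr_matrix, ap: sp.csr_matrix) -> sp.csr_matrix:
    """Transition probabilities along a PAP metapath: row-normalize, then chain."""
    def row_normalize(a):
        deg = np.asarray(a.sum(axis=1)).ravel()
        return sp.diags(1.0 / np.where(deg > 0, deg, 1)) @ a
    return row_normalize(pa) @ row_normalize(ap)


def purify_attention(att_scores, trans_probs, T=3):
    """Keep each node's top-T neighbors ranked by attention * transition prior."""
    confidence = att_scores * trans_probs         # dense (N, N) arrays, for illustration
    purified = np.full_like(att_scores, -np.inf)  # -inf vanishes under a later softmax
    for i in range(confidence.shape[0]):
        keep = np.argsort(confidence[i])[-T:]     # the T most reliable neighbors of node i
        purified[i, keep] = att_scores[i, keep]
    return purified
```

In the actual layers, `T` corresponds to the per-edge-type `'T'` entry in `settings` and the transition prior to `'TransM'`, as set up in `rohehan_trainer.py` below.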
## Original Paper Results

The original paper reports the following performance metrics under clean and adversarial settings:

| Dataset | Clean (no attack) | Attack (1 perturbation) | Attack (3 perturbations) | Attack (5 perturbations) |
| ------- | ----------------- | ----------------------- | ------------------------ | ------------------------ |
| ACM | 0.920 | 0.904 | 0.902 | 0.882 |
examples/rohehan/rohehan_trainer.py

Lines changed: 277 additions & 0 deletions
# -*- coding: UTF-8 -*-
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: Output all; 1: Filter out INFO; 2: Filter out INFO and WARNING; 3: Filter out INFO, WARNING, and ERROR
import argparse
import numpy as np
import tensorlayerx as tlx
from gammagl.models import RoheHAN
from utils import *
import pickle as pkl
from gammagl.utils import mask_to_index
from gammagl.utils import edge_index_to_adj_matrix
from gammagl.datasets.acm4rohe import ACM4Rohe


class SemiSpvzLoss(tlx.nn.Module):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__()
        self.net = net
        self.loss_fn = loss_fn

    def forward(self, data, y):
        logits = self.net(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
        train_logits = tlx.gather(logits['paper'], data['train_idx'])
        train_y = tlx.gather(y, data['train_idx'])
        loss = self.loss_fn(train_logits, train_y)
        return loss


# Evaluate the model, returning loss and accuracy scores
def evaluate(model, data, labels, mask, loss_func):
    model.set_eval()
    logits = model(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
    logits = logits['paper']  # Focus evaluation on 'paper' nodes
    mask_indices = mask  # Assuming mask is an array of indices
    logits_masked = tlx.gather(logits, tlx.convert_to_tensor(mask_indices, dtype=tlx.int64))
    labels_masked = tlx.gather(labels, tlx.convert_to_tensor(mask_indices, dtype=tlx.int64))
    loss = loss_func(logits_masked, labels_masked)

    accuracy, micro_f1, macro_f1 = score(logits_masked, labels_masked)
    return loss, accuracy, micro_f1, macro_f1


def main(args):
    # Load ACM raw dataset
    dataname = 'acm'
    dataset = ACM4Rohe(root=args.dataset_path)
    g = dataset[0]
    features_dict = {ntype: g[ntype].x for ntype in g.node_types if hasattr(g[ntype], 'x')}
    labels = g['paper'].y
    train_mask = g['paper'].train_mask
    val_mask = g['paper'].val_mask
    test_mask = g['paper'].test_mask

    # Compute number of classes
    num_classes = int(tlx.reduce_max(labels)) + 1

    # Get train_idx, val_idx, test_idx from masks
    train_idx = mask_to_index(train_mask)
    val_idx = mask_to_index(val_mask)
    test_idx = mask_to_index(test_mask)

    x_dict = features_dict
    y = labels
    features = features_dict['paper']

    # Define meta-paths (PAP, PFP)
    meta_paths = [[('paper', 'pa', 'author'), ('author', 'ap', 'paper')],
                  [('paper', 'pf', 'field'), ('field', 'fp', 'paper')]]

    # Define initial settings for each edge type
    settings = {
        ('paper', 'author', 'paper'): {'T': 3, 'TransM': None},
        ('paper', 'field', 'paper'): {'T': 5, 'TransM': None},
    }

    # Prepare adjacency matrices
    hete_adjs = {
        'pa': edge_index_to_adj_matrix(g['paper', 'pa', 'author'].edge_index, g['paper'].num_nodes, g['author'].num_nodes),
        'ap': edge_index_to_adj_matrix(g['author', 'ap', 'paper'].edge_index, g['author'].num_nodes, g['paper'].num_nodes),
        'pf': edge_index_to_adj_matrix(g['paper', 'pf', 'field'].edge_index, g['paper'].num_nodes, g['field'].num_nodes),
        'fp': edge_index_to_adj_matrix(g['field', 'fp', 'paper'].edge_index, g['field'].num_nodes, g['paper'].num_nodes)
    }
    meta_g = dataset.get_meta_graph(dataname, hete_adjs, features_dict, labels, train_mask, val_mask, test_mask)
    # Prepare edge index and node count dictionaries
    edge_index_dict = {etype: meta_g[etype].edge_index for etype in meta_g.edge_types}
    num_nodes_dict = {ntype: meta_g[ntype].num_nodes for ntype in meta_g.node_types}

    # Compute edge transition matrices
    trans_edge_weights_list = get_transition(hete_adjs, meta_paths, edge_index_dict, meta_g.metadata()[1])
    for i, edge_type in enumerate(meta_g.metadata()[1]):
        settings[edge_type]['TransM'] = trans_edge_weights_list[i]

    layer_settings = [settings, settings]

    # Initialize the RoheHAN model
    model = RoheHAN(
        metadata=meta_g.metadata(),
        in_channels=features.shape[1],
        hidden_size=args.hidden_units,
        out_size=num_classes,
        num_heads=args.num_heads,
        dropout_rate=args.dropout,
        settings=layer_settings
    )

    # Define optimizer and loss function
    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay)
    loss_func = tlx.losses.softmax_cross_entropy_with_logits
    semi_spvz_loss = SemiSpvzLoss(model, loss_func)

    # Prepare training components
    train_weights = model.trainable_weights
    train_one_step = tlx.model.TrainOneStep(semi_spvz_loss, optimizer, train_weights)

    # Prepare data dictionary
    data = {
        "x_dict": x_dict,
        "edge_index_dict": edge_index_dict,
        "num_nodes_dict": num_nodes_dict,
        "train_idx": tlx.convert_to_tensor(train_idx, dtype=tlx.int64),
        "val_idx": tlx.convert_to_tensor(val_idx, dtype=tlx.int64),
        "test_idx": tlx.convert_to_tensor(test_idx, dtype=tlx.int64),
        "y": y
    }

    # Training loop
    best_val_acc = 0.0

    for epoch in range(args.num_epochs):
        model.set_train()
        # Forward and backward pass
        loss = train_one_step(data, y)

        # Evaluate on validation set
        model.set_eval()
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, data, y, val_idx, loss_func)

        print(f"Epoch {epoch+1} | Train Loss: {loss.item():.4f} | Val Micro-F1: {val_micro_f1:.4f} | Val Macro-F1: {val_macro_f1:.4f}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Save model weights
            model.save_weights(os.path.join(args.best_model_path, 'best_model.npz'), format='npz_dict')

    # Load the best model
    model.load_weights(os.path.join(args.best_model_path, 'best_model.npz'), format='npz_dict')

    # Test the model
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, data, y, test_idx, loss_func)
    print(f"Test Micro-F1: {test_micro_f1:.4f} | Test Macro-F1: {test_macro_f1:.4f}")

    # Load target node IDs
    print("Loading target nodes")
    tar_idx = []
    # All 500 target nodes can be attacked by setting range(5)
    for i in range(1):
        target_filename = os.path.join(args.dataset_path, f'ACM4Rohe/raw/data/preprocess/target_nodes/acm_r_target{i}.pkl')
        with open(target_filename, 'rb') as f:
            tar_tmp = np.sort(pkl.load(f))
        tar_idx.extend(tar_tmp)

    # Evaluate on target nodes
    model.set_eval()
    logits_dict = model(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
    logits_clean = tlx.gather(logits_dict['paper'], tlx.convert_to_tensor(tar_idx, dtype=tlx.int64))
    labels_clean = tlx.gather(y, tlx.convert_to_tensor(tar_idx, dtype=tlx.int64))
    _, tar_micro_f1_clean, tar_macro_f1_clean = score(logits_clean, labels_clean)
    print(f"Clean data: Micro-F1: {tar_micro_f1_clean:.4f} | Macro-F1: {tar_macro_f1_clean:.4f}")

    # Load adversarial attacks
    n_perturbation = 1
    adv_filename = os.path.join(args.dataset_path, 'ACM4Rohe/raw/data/generated_attacks', f'adv_acm_pap_pa_{n_perturbation}.pkl')
    with open(adv_filename, 'rb') as f:
        modified_opt = pkl.load(f)

    # Apply adversarial attack
    logits_adv_list = []
    labels_adv_list = []
    for items in modified_opt:
        target_node = items[0]
        del_list = items[2]
        add_list = items[3]
        if target_node not in tar_idx:
            continue

        # Modify adjacency matrices for the attack
        mod_hete_adj_dict = {}
        for key in hete_adjs.keys():
            mod_hete_adj_dict[key] = hete_adjs[key].tolil()

        # Delete and add edges
        for edge in del_list:
            mod_hete_adj_dict['pa'][edge[0], edge[1]] = 0
            mod_hete_adj_dict['ap'][edge[1], edge[0]] = 0
        for edge in add_list:
            mod_hete_adj_dict['pa'][edge[0], edge[1]] = 1
            mod_hete_adj_dict['ap'][edge[1], edge[0]] = 1

        for key in mod_hete_adj_dict.keys():
            mod_hete_adj_dict[key] = mod_hete_adj_dict[key].tocsc()

        # Update edge index dictionary for the attack
        edge_index_dict_atk = {}
        meta_path_atk = [('paper', 'author', 'paper'), ('paper', 'field', 'paper')]
        for idx, edge_type in enumerate(meta_path_atk):
            # Recompute adjacency matrices for the attack
            if edge_type == ('paper', 'author', 'paper'):
                adj_matrix = mod_hete_adj_dict['pa'].dot(mod_hete_adj_dict['ap'])
            elif edge_type == ('paper', 'field', 'paper'):
                adj_matrix = mod_hete_adj_dict['pf'].dot(mod_hete_adj_dict['fp'])
            else:
                raise KeyError(f"Unknown edge type {edge_type}")

            src, dst = adj_matrix.nonzero()
            edge_index = np.vstack((src, dst))
            edge_index_dict_atk[edge_type] = edge_index

        # Update transition matrices for the attack
        trans_edge_weights_list = get_transition(mod_hete_adj_dict, meta_paths, edge_index_dict_atk, meta_path_atk)

        for i, edge_type in enumerate(meta_path_atk):
            key = '__'.join(edge_type)
            if key in model.layer_list[0].gat_layers:
                model.layer_list[0].gat_layers[key].settings['TransM'] = trans_edge_weights_list[i]
            else:
                raise KeyError(f"Edge type key '{key}' not found in gat_layers.")

        # Prepare modified graph and data
        mod_features_dict = {'paper': features}
        g_atk = dataset.get_meta_graph(dataname, mod_hete_adj_dict, mod_features_dict, y, train_mask, val_mask, test_mask)
        data_atk = {
            "x_dict": g_atk.x_dict,
            "edge_index_dict": {etype: g_atk[etype].edge_index for etype in g_atk.edge_types},
            "num_nodes_dict": {ntype: g_atk[ntype].num_nodes for ntype in g_atk.node_types},
        }

        # Run the model on the attacked graph
        model.set_eval()
        with no_grad():
            logits_dict_atk = model(data_atk['x_dict'], data_atk['edge_index_dict'], data_atk['num_nodes_dict'])
        logits_atk = logits_dict_atk['paper']
        logits_adv = tlx.gather(logits_atk, tlx.convert_to_tensor([target_node], dtype=tlx.int64))
        label_adv = tlx.gather(y, tlx.convert_to_tensor([target_node], dtype=tlx.int64))

        logits_adv_list.append(logits_adv)
        labels_adv_list.append(label_adv)

    logits_adv = tlx.concat(logits_adv_list, axis=0)
    labels_adv = tlx.concat(labels_adv_list, axis=0)

    # Evaluate adversarial attack
    _, tar_micro_f1_atk, tar_macro_f1_atk = score(logits_adv, labels_adv)
    print(f"Attacked data: Micro-F1: {tar_micro_f1_atk:.4f} | Macro-F1: {tar_macro_f1_atk:.4f}")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=2, help="Random seed.")
    parser.add_argument("--lr", type=float, default=0.005, help="Learning rate.")
    # nargs='+' so the argument parses as a list, matching the list default the model expects
    parser.add_argument("--num_heads", type=int, nargs='+', default=[8], help="Number of attention heads per layer.")
    parser.add_argument("--hidden_units", type=int, default=8, help="Hidden units.")
    parser.add_argument("--dropout", type=float, default=0.6, help="Dropout rate.")
    parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay.")
    parser.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index. Use -1 for CPU.")
    parser.add_argument("--dataset_path", type=str, default=r'', help="Path to save the dataset.")
    parser.add_argument("--best_model_path", type=str, default='./', help="Path to save the best model.")
    args = parser.parse_args()

    # Setup configuration
    tlx.set_seed(args.seed)
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)

examples/rohehan/utils.py

Lines changed: 57 additions & 0 deletions
# -*- coding: UTF-8 -*-
import os
import numpy as np
import tensorlayerx as tlx
from sklearn.metrics import f1_score
import scipy.sparse as sp
from contextlib import nullcontext


# Evaluation function for accuracy and F1-score
def score(logits, labels):
    predictions = tlx.argmax(logits, axis=1)
    predictions = tlx.convert_to_numpy(predictions)
    labels = tlx.convert_to_numpy(labels)

    accuracy = np.sum(predictions == labels) / len(predictions)
    micro_f1 = f1_score(labels, predictions, average='micro')
    macro_f1 = f1_score(labels, predictions, average='macro')
    return accuracy, micro_f1, macro_f1
# Compute the transition matrix for edge types based on meta-paths
def get_transition(given_hete_adjs, metapath_info, edge_index_dict, edge_types):
    # Row-normalize each heterogeneous adjacency matrix so each row sums to 1
    hete_adj_dict_tmp = {}
    for key in given_hete_adjs.keys():
        deg = given_hete_adjs[key].sum(1).A1
        deg_inv = 1 / np.where(deg > 0, deg, 1)
        deg_inv_mat = sp.diags(deg_inv)
        hete_adj_dict_tmp[key] = deg_inv_mat.dot(given_hete_adjs[key])

    trans_edge_weights_list = []
    for i, metapath in enumerate(metapath_info):
        # Chain the normalized adjacencies along the metapath, e.g. PA . AP for PAP
        adj = hete_adj_dict_tmp[metapath[0][1]]
        for etype in metapath[1:]:
            adj = adj.dot(hete_adj_dict_tmp[etype[1]])

        edge_type = edge_types[i]
        edge_index = edge_index_dict[edge_type]

        # Read off the transition probability for each metapath-level edge
        edge_trans_values = adj[edge_index[0], edge_index[1]].A1
        trans_edge_weights_list.append(edge_trans_values)
    return trans_edge_weights_list
# Disable gradient computation; returns a context manager appropriate for the backend
def no_grad():
    if tlx.BACKEND == 'torch':
        import torch
        return torch.no_grad()
    elif tlx.BACKEND == 'tensorflow':
        # TensorFlow only records gradients inside a tf.GradientTape, so nothing to disable
        return nullcontext()
    elif tlx.BACKEND == 'paddle':
        import paddle
        return paddle.no_grad()
    elif tlx.BACKEND == 'mindspore':
        # MindSpore has no no_grad context manager; set_context returns None, which would
        # fail in a `with` block, so switch to PyNative mode and return a no-op context
        import mindspore
        mindspore.context.set_context(mode=mindspore.context.PYNATIVE_MODE)
        return nullcontext()
    else:
        raise NotImplementedError(f"Unsupported backend: {tlx.BACKEND}")
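
For reference, here is a toy usage sketch of `get_transition` (not part of the commit; it assumes SciPy CSR adjacencies like those produced by `edge_index_to_adj_matrix` in the trainer):

```python
import numpy as np
import scipy.sparse as sp
from utils import get_transition  # this file

# Toy bipartite graph: 3 papers, 2 authors
pa = sp.csr_matrix(np.array([[1, 0], [1, 1], [0, 1]], dtype=float))
ap = sp.csr_matrix(pa.T)
hete_adjs = {'pa': pa, 'ap': ap}

# One metapath PAP, given as (src_type, relation, dst_type) triples
meta_paths = [[('paper', 'pa', 'author'), ('author', 'ap', 'paper')]]

# Score every paper-paper edge reachable via PAP
src, dst = (pa @ ap).nonzero()
edge_index_dict = {('paper', 'author', 'paper'): np.vstack((src, dst))}

weights = get_transition(hete_adjs, meta_paths, edge_index_dict,
                         [('paper', 'author', 'paper')])
print(weights[0])  # one transition probability per PAP edge
```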
