# -*- coding: UTF-8 -*-
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR
import argparse
import numpy as np
import pickle as pkl
import tensorlayerx as tlx
from gammagl.models import RoheHAN
from utils import *
from gammagl.utils import mask_to_index, edge_index_to_adj_matrix
from gammagl.datasets.acm4rohe import ACM4Rohe

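# Wrapper used by tlx.model.TrainOneStep: its forward computes the full-graph
# logits and returns the scalar cross-entropy loss over the training indices only.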
class SemiSpvzLoss(tlx.nn.Module):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__()
        self.net = net
        self.loss_fn = loss_fn

    def forward(self, data, y):
        logits = self.net(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
        train_logits = tlx.gather(logits['paper'], data['train_idx'])
        train_y = tlx.gather(y, data['train_idx'])
        loss = self.loss_fn(train_logits, train_y)
        return loss

# Evaluate the model, returning the loss, accuracy, and micro/macro F1 scores
def evaluate(model, data, labels, mask, loss_func):
    model.set_eval()
    logits = model(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
    logits = logits['paper']  # Evaluation focuses on 'paper' nodes
    mask_indices = tlx.convert_to_tensor(mask, dtype=tlx.int64)  # mask is an array of node indices
    logits_masked = tlx.gather(logits, mask_indices)
    labels_masked = tlx.gather(labels, mask_indices)
    loss = loss_func(logits_masked, labels_masked)

    accuracy, micro_f1, macro_f1 = score(logits_masked, labels_masked)
    return loss, accuracy, micro_f1, macro_f1

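# NOTE: `score` comes from the star import of utils; judging from its call sites,
# it is assumed to return (accuracy, micro_f1, macro_f1) for the given logits/labels.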
def main(args):
    # Load the raw ACM dataset
    dataname = 'acm'
    dataset = ACM4Rohe(root=args.dataset_path)
    g = dataset[0]
    features_dict = {ntype: g[ntype].x for ntype in g.node_types if hasattr(g[ntype], 'x')}
    labels = g['paper'].y
    train_mask = g['paper'].train_mask
    val_mask = g['paper'].val_mask
    test_mask = g['paper'].test_mask

    # Compute the number of classes
    num_classes = int(tlx.reduce_max(labels)) + 1

    # Convert boolean masks to index arrays
    train_idx = mask_to_index(train_mask)
    val_idx = mask_to_index(val_mask)
    test_idx = mask_to_index(test_mask)

    x_dict = features_dict
    y = labels
    features = features_dict['paper']

    # Define meta-paths (PAP, PFP)
    meta_paths = [[('paper', 'pa', 'author'), ('author', 'ap', 'paper')],
                  [('paper', 'pf', 'field'), ('field', 'fp', 'paper')]]

    # Initial settings for each meta-path-induced edge type
    settings = {
        ('paper', 'author', 'paper'): {'T': 3, 'TransM': None},
        ('paper', 'field', 'paper'): {'T': 5, 'TransM': None},
    }
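    # Per the RoHe attention-purification scheme, 'T' is the number of most reliable
    # neighbors kept per node, and 'TransM' holds the meta-path transition matrix
    # used to score neighbor reliability (filled in below).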

    # Build (scipy sparse) adjacency matrices for each relation
    hete_adjs = {
        'pa': edge_index_to_adj_matrix(g['paper', 'pa', 'author'].edge_index, g['paper'].num_nodes, g['author'].num_nodes),
        'ap': edge_index_to_adj_matrix(g['author', 'ap', 'paper'].edge_index, g['author'].num_nodes, g['paper'].num_nodes),
        'pf': edge_index_to_adj_matrix(g['paper', 'pf', 'field'].edge_index, g['paper'].num_nodes, g['field'].num_nodes),
        'fp': edge_index_to_adj_matrix(g['field', 'fp', 'paper'].edge_index, g['field'].num_nodes, g['paper'].num_nodes)
    }
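    # These adjacency matrices are reused twice: to assemble the meta-path graph
    # below and, later, to apply the edge-level perturbations of the attack.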
    meta_g = dataset.get_meta_graph(dataname, hete_adjs, features_dict, labels, train_mask, val_mask, test_mask)
    # Prepare edge index and node count dictionaries
    edge_index_dict = {etype: meta_g[etype].edge_index for etype in meta_g.edge_types}
    num_nodes_dict = {ntype: meta_g[ntype].num_nodes for ntype in meta_g.node_types}

    # Compute the edge transition matrices
    trans_edge_weights_list = get_transition(hete_adjs, meta_paths, edge_index_dict, meta_g.metadata()[1])
    for i, edge_type in enumerate(meta_g.metadata()[1]):
        settings[edge_type]['TransM'] = trans_edge_weights_list[i]

    layer_settings = [settings, settings]
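    # Both RoheHAN layers share the same purification settings here; a different
    # dict per layer could be supplied instead.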

    # Initialize the RoheHAN model
    model = RoheHAN(
        metadata=meta_g.metadata(),
        in_channels=features.shape[1],
        hidden_size=args.hidden_units,
        out_size=num_classes,
        num_heads=args.num_heads,
        dropout_rate=args.dropout,
        settings=layer_settings
    )

    # Define the optimizer and loss function
    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay)
    loss_func = tlx.losses.softmax_cross_entropy_with_logits
    semi_spvz_loss = SemiSpvzLoss(model, loss_func)

    # Prepare training components
    train_weights = model.trainable_weights
    train_one_step = tlx.model.TrainOneStep(semi_spvz_loss, optimizer, train_weights)
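    # TrainOneStep binds the loss module, optimizer, and trainable weights;
    # each call performs one forward pass, backward pass, and parameter update.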

    # Prepare the data dictionary
    data = {
        "x_dict": x_dict,
        "edge_index_dict": edge_index_dict,
        "num_nodes_dict": num_nodes_dict,
        "train_idx": tlx.convert_to_tensor(train_idx, dtype=tlx.int64),
        "val_idx": tlx.convert_to_tensor(val_idx, dtype=tlx.int64),
        "test_idx": tlx.convert_to_tensor(test_idx, dtype=tlx.int64),
        "y": y
    }

    # Training loop
    best_val_acc = 0.0

    for epoch in range(args.num_epochs):
        model.set_train()
        # Forward and backward pass
        loss = train_one_step(data, y)

        # Evaluate on the validation set (evaluate() switches the model to eval mode)
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, data, y, val_idx, loss_func)

        print(f"Epoch {epoch + 1} | Train Loss: {loss.item():.4f} | Val Micro-F1: {val_micro_f1:.4f} | Val Macro-F1: {val_macro_f1:.4f}")

        # Save the best model by validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model.save_weights(os.path.join(args.best_model_path, 'best_model.npz'), format='npz_dict')

    # Load the best model
    model.load_weights(os.path.join(args.best_model_path, 'best_model.npz'), format='npz_dict')

    # Test the model
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, data, y, test_idx, loss_func)
    print(f"Test Micro-F1: {test_micro_f1:.4f} | Test Macro-F1: {test_macro_f1:.4f}")

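    # RoHe's robustness protocol evaluates a fixed set of pre-selected target nodes,
    # first on the clean graph and then under the generated adversarial perturbations.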
    # Load target node IDs
    print("Loading target nodes")
    tar_idx = []
    # Up to 500 target nodes can be attacked by setting range(5)
    for i in range(1):
        target_filename = os.path.join(args.dataset_path, f'ACM4Rohe/raw/data/preprocess/target_nodes/acm_r_target{i}.pkl')
        with open(target_filename, 'rb') as f:
            tar_tmp = np.sort(pkl.load(f))
        tar_idx.extend(tar_tmp)

    # Evaluate on target nodes over the clean graph
    model.set_eval()
    logits_dict = model(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
    logits_clean = tlx.gather(logits_dict['paper'], tlx.convert_to_tensor(tar_idx, dtype=tlx.int64))
    labels_clean = tlx.gather(y, tlx.convert_to_tensor(tar_idx, dtype=tlx.int64))
    _, tar_micro_f1_clean, tar_macro_f1_clean = score(logits_clean, labels_clean)
    print(f"Clean data: Micro-F1: {tar_micro_f1_clean:.4f} | Macro-F1: {tar_macro_f1_clean:.4f}")

    # Load the pre-generated adversarial attacks
    n_perturbation = 1
    adv_filename = os.path.join(args.dataset_path, 'ACM4Rohe/raw/data/generated_attacks', f'adv_acm_pap_pa_{n_perturbation}.pkl')
    with open(adv_filename, 'rb') as f:
        modified_opt = pkl.load(f)

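    # Judging from the loop below, each record in modified_opt describes one attack:
    # index 0 is the target paper id, index 2 the (paper, author) edges to delete,
    # and index 3 the edges to add.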
    # Apply each adversarial attack and collect the perturbed predictions
    logits_adv_list = []
    labels_adv_list = []
    for items in modified_opt:
        target_node = items[0]
        del_list = items[2]
        add_list = items[3]
        if target_node not in tar_idx:
            continue

        # Copy the adjacency matrices in LIL format for efficient edge edits
        mod_hete_adj_dict = {}
        for key in hete_adjs.keys():
            mod_hete_adj_dict[key] = hete_adjs[key].tolil()

        # Delete and add edges (keeping 'pa' and 'ap' symmetric)
        for edge in del_list:
            mod_hete_adj_dict['pa'][edge[0], edge[1]] = 0
            mod_hete_adj_dict['ap'][edge[1], edge[0]] = 0
        for edge in add_list:
            mod_hete_adj_dict['pa'][edge[0], edge[1]] = 1
            mod_hete_adj_dict['ap'][edge[1], edge[0]] = 1

        for key in mod_hete_adj_dict.keys():
            mod_hete_adj_dict[key] = mod_hete_adj_dict[key].tocsc()

        # Rebuild the edge index dictionary for the attacked graph
        edge_index_dict_atk = {}
        meta_path_atk = [('paper', 'author', 'paper'), ('paper', 'field', 'paper')]
        for idx, edge_type in enumerate(meta_path_atk):
            # Recompute the meta-path adjacency matrices from the perturbed relations
            if edge_type == ('paper', 'author', 'paper'):
                adj_matrix = mod_hete_adj_dict['pa'].dot(mod_hete_adj_dict['ap'])
            elif edge_type == ('paper', 'field', 'paper'):
                adj_matrix = mod_hete_adj_dict['pf'].dot(mod_hete_adj_dict['fp'])
            else:
                raise KeyError(f"Unknown edge type {edge_type}")

            src, dst = adj_matrix.nonzero()
            edge_index = np.vstack((src, dst))
            edge_index_dict_atk[edge_type] = edge_index

        # Update the transition matrices for the attacked graph
        trans_edge_weights_list = get_transition(mod_hete_adj_dict, meta_paths, edge_index_dict_atk, meta_path_atk)

        for i, edge_type in enumerate(meta_path_atk):
            key = '__'.join(edge_type)
            if key in model.layer_list[0].gat_layers:
                model.layer_list[0].gat_layers[key].settings['TransM'] = trans_edge_weights_list[i]
            else:
                raise KeyError(f"Edge type key '{key}' not found in gat_layers.")

        # Prepare the attacked graph and data
        mod_features_dict = {'paper': features}
        g_atk = dataset.get_meta_graph(dataname, mod_hete_adj_dict, mod_features_dict, y, train_mask, val_mask, test_mask)
        data_atk = {
            "x_dict": g_atk.x_dict,
            "edge_index_dict": {etype: g_atk[etype].edge_index for etype in g_atk.edge_types},
            "num_nodes_dict": {ntype: g_atk[ntype].num_nodes for ntype in g_atk.node_types},
        }

        # Run the model on the attacked graph (inference only; no gradient step is taken)
        model.set_eval()
        logits_dict_atk = model(data_atk['x_dict'], data_atk['edge_index_dict'], data_atk['num_nodes_dict'])
        logits_atk = logits_dict_atk['paper']
        logits_adv = tlx.gather(logits_atk, tlx.convert_to_tensor([target_node], dtype=tlx.int64))
        label_adv = tlx.gather(y, tlx.convert_to_tensor([target_node], dtype=tlx.int64))

        logits_adv_list.append(logits_adv)
        labels_adv_list.append(label_adv)

    logits_adv = tlx.concat(logits_adv_list, axis=0)
    labels_adv = tlx.concat(labels_adv_list, axis=0)

    # Evaluate under the adversarial attack
    _, tar_micro_f1_atk, tar_macro_f1_atk = score(logits_adv, labels_adv)
    print(f"Attacked data: Micro-F1: {tar_micro_f1_atk:.4f} | Macro-F1: {tar_macro_f1_atk:.4f}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=2, help="Random seed.")
    parser.add_argument("--lr", type=float, default=0.005, help="Learning rate.")
    parser.add_argument("--num_heads", type=int, nargs='+', default=[8], help="Number of attention heads per layer.")
    parser.add_argument("--hidden_units", type=int, default=8, help="Hidden units.")
    parser.add_argument("--dropout", type=float, default=0.6, help="Dropout rate.")
    parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay.")
    parser.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index. Use -1 for CPU.")
    parser.add_argument("--dataset_path", type=str, default=r'', help="Path to save the dataset.")
    parser.add_argument("--best_model_path", type=str, default='./', help="Path to save the best model.")
    args = parser.parse_args()

    # Setup configuration
    tlx.set_seed(args.seed)
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)