import os
os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid
from gammagl.utils import add_self_loops, mask_to_index, calc_gcn_norm
import time
import scipy.sparse as sp
from gammagl.models import GCNModel, GCNIIModel, GATModel
import tensorlayerx.nn as nn
from tensorlayerx.model import WithLoss, TrainOneStep
from utils import CITModule, dense_mincut_pool, edge_index_to_csr_matrix, AssignmentMatricsMLP, F1score, dense_to_sparse, calculate_acc, get_model
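
# Overview (descriptive comment, inferred from the code below): this script trains a GNN
# backbone (GCN / GAT / GCNII, selected via --gnn through the local get_model helper) on a
# Planetoid citation graph. On top of the supervised cross-entropy loss it adds a cluster-based
# regularizer: node embeddings are soft-assigned to clusters via the local CITModule and
# AssignmentMatricsMLP, and the mincut and orthogonality losses from dense_mincut_pool are
# mixed into the objective. The best model on the validation split is saved and then evaluated
# on a structure-shifted adjacency loaded from ./datasets/<dataset>_add_<ss>.npz.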

log_softmax = nn.activation.LogSoftmax(dim=-1)


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn, cit_module, mlp, adj_matrix):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)
        self.cit_module = cit_module
        self.mlp = mlp
        self.adj_matrix = adj_matrix

    def forward(self, data, label):
        # Each backbone exposes a slightly different call signature, so the forward pass
        # (and the intermediate features fed to the CIT module) is dispatched on the model type.
        if isinstance(self.backbone_network, GCNModel):
            intermediate_features = self.backbone_network.conv[0](data['x'], data['edge_index'], None, data["num_nodes"])
            intermediate_features = self.backbone_network.relu(intermediate_features)
            logits = log_softmax(self.backbone_network(data['x'], data['edge_index'], None, data["num_nodes"]))
        elif isinstance(self.backbone_network, GATModel):
            intermediate_features = log_softmax(self.backbone_network(data['x'], data['edge_index'], data["num_nodes"]))
            logits = log_softmax(self.backbone_network(data['x'], data['edge_index'], data["num_nodes"]))
        elif isinstance(self.backbone_network, GCNIIModel):
            intermediate_features = self.backbone_network(data['x'], data['edge_index'], data["edge_weight"], data["num_nodes"])
            logits = log_softmax(self.backbone_network(data['x'], data['edge_index'], data['edge_weight'], data["num_nodes"]))

        if not isinstance(self.backbone_network, GCNModel):
            # Project the backbone output to hidden_dim for the assignment MLP
            # (note: this Linear layer is re-created on every forward call and reads the global args).
            intermediate_features = nn.Linear(in_features=intermediate_features.shape[1], out_features=args.hidden_dim, name='linear_1', act=tlx.relu)(intermediate_features)

        train_logits = tlx.gather(logits, data['train_idx'])
        train_y = tlx.gather(data['y'], data['train_idx'])

        # CIT: soft-assign nodes to clusters, then regularize with the mincut (mc_loss)
        # and orthogonality (o_loss) terms returned by dense_mincut_pool.
        assignment_matrics, _ = self.cit_module.forward(intermediate_features, self.mlp)
        pooled_x, pooled_adj, mc_loss, o_loss = dense_mincut_pool(data['x'], self.adj_matrix, assignment_matrics)
        loss = 0.55 * self._loss_fn(train_logits, train_y) + 0.25 * mc_loss + 0.2 * o_loss
        # loss = self._loss_fn(train_logits, train_y)
        return loss

def test(net, data, args, metrics):
    net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict')
    net.set_eval()

    # Evaluate on the structure-shifted graph: a pre-computed, perturbed adjacency
    # (file suffix _add_<ss>) corresponding to the shift level --ss.
    adj_add = sp.load_npz(f"./datasets/{args.dataset}_add_{args.ss}.npz")
    adj_add = dense_to_sparse(adj_add)

    edge_weight = tlx.convert_to_tensor(calc_gcn_norm(adj_add, data["num_nodes"]))

    if isinstance(net, GATModel):
        logits = log_softmax(net(data['x'], adj_add, data["num_nodes"]))
    elif isinstance(net, GCNModel):
        logits = log_softmax(net(data['x'], adj_add, None, data["num_nodes"]))
    elif isinstance(net, GCNIIModel):
        logits = log_softmax(net(data['x'], adj_add, edge_weight, data['num_nodes']))

    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])

    acc_test = calculate_acc(test_logits, test_y, metrics)
    f1 = F1score(test_logits, test_y)

    print('Test set results:',
          'acc_test: {:.4f}'.format(acc_test.item()),
          'f1_score: {:.4f}'.format(f1.item()))

def main(args):
    if args.dataset.lower() not in ['cora', 'pubmed', 'citeseer']:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    dataset = Planetoid(args.dataset_path, args.dataset)
    graph = dataset[0]
    edge_index, _ = add_self_loops(graph.edge_index, num_nodes=graph.num_nodes)
    edge_weight = tlx.convert_to_tensor(calc_gcn_norm(edge_index, graph.num_nodes))

    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)

    net = get_model(args, dataset)

    loss = tlx.losses.softmax_cross_entropy_with_logits
    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    cit_module = CITModule(clusters=args.clusters, p=args.p)
    mlp = AssignmentMatricsMLP(input_dim=args.hidden_dim, num_clusters=args.clusters, activation='relu')
    adj_matrix = edge_index_to_csr_matrix(edge_index, graph.num_nodes)
    adj_matrix = tlx.convert_to_tensor(adj_matrix.toarray(), dtype=tlx.float32)
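    # The CSR adjacency is densified here because dense_mincut_pool (as used in SemiSpvzLoss)
    # operates on a dense adjacency tensor; this costs O(num_nodes^2) memory, which is fine
    # for Planetoid-scale graphs but would not scale to much larger ones.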

    loss_func = SemiSpvzLoss(net, loss, cit_module, mlp, adj_matrix)
    # loss_func = SemiSpvzLoss(net, loss)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)
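    # TrainOneStep bundles loss computation, backward pass, and optimizer update: each call in
    # the loop below evaluates loss_func on the full graph, backpropagates, applies the
    # gradients to train_weights, and returns the scalar training loss.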

    data = {
        "x": graph.x,
        "y": graph.y,
        "edge_index": edge_index,
        "edge_weight": edge_weight,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "num_nodes": graph.num_nodes,
    }

    best_val_acc = 0
    for epoch in range(args.epochtimes):
        t = time.time()
        net.set_train()
        train_loss = train_one_step(data, graph.y)
        net.set_eval()

        if isinstance(net, GATModel):
            logits = log_softmax(net(data['x'], data['edge_index'], data["num_nodes"]))
        elif isinstance(net, GCNIIModel):
            logits = log_softmax(net(data['x'], data['edge_index'], data["edge_weight"], data["num_nodes"]))
        elif isinstance(net, GCNModel):
            logits = log_softmax(net(data['x'], data['edge_index'], None, data["num_nodes"]))

        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        acc_val = calculate_acc(val_logits, val_y, metrics)

        print('Epoch: {:04d}'.format(epoch + 1),
              'train_loss: {:.4f}'.format(train_loss.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))

        # Keep the checkpoint with the best validation accuracy.
        if acc_val > best_val_acc:
            best_val_acc = acc_val
            net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict')

    test(net, data, args, metrics)


if __name__ == '__main__':
    # set arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--gnn", type=str, default="gcn", help="backbone GNN model")
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--dataset", type=str, default="cora", help="dataset")
    parser.add_argument("--droprate", type=float, default=0.4)
    parser.add_argument("--p", type=float, default=0.2)
    parser.add_argument("--epochtimes", type=int, default=400)
    parser.add_argument("--clusters", type=int, default=100)
    parser.add_argument("--hidden_dim", type=int, default=8)
    parser.add_argument("--gpu", type=int, default=-1)
    parser.add_argument("--dataset_path", type=str, default=r'')
    parser.add_argument("--ss", type=float, default=0.5, help="structure shift")
    parser.add_argument("--l2_coef", type=float, default=5e-4)
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")

    args = parser.parse_args()
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
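
# Example invocation (hypothetical script name; the test stage additionally expects the
# pre-computed structure-shifted adjacency ./datasets/<dataset>_add_<ss>.npz to be present):
#   python cit_trainer.py --gnn gcn --dataset cora --hidden_dim 8 --clusters 100 --ss 0.5 --gpu 0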