
Commit 6c41b6d

spanko-otwr and twr authored

CITGNN by twr (#221)

* CITGNN trainer, utils and readme
* modified to fit multi backend
* fitting multi-backend
* add testing files

Co-authored-by: twr <[email protected]>

1 parent c9d2183

File tree

9 files changed: +422 −0 lines changed

examples/citgnn/citgnn_trainer.py

Lines changed: 169 additions & 0 deletions
```python
import os
os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

import argparse
import time

import scipy.sparse as sp
import tensorlayerx as tlx
import tensorlayerx.nn as nn
from tensorlayerx.model import WithLoss, TrainOneStep

from gammagl.datasets import Planetoid
from gammagl.models import GCNModel, GCNIIModel, GATModel
from gammagl.utils import add_self_loops, mask_to_index, calc_gcn_norm

from utils import CITModule, dense_mincut_pool, edge_index_to_csr_matrix, AssignmentMatricsMLP, F1score, dense_to_sparse, calculate_acc, get_model

log_softmax = nn.activation.LogSoftmax(dim=-1)


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn, cit_module, mlp, adj_matrix):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)
        self.cit_module = cit_module
        self.mlp = mlp
        self.adj_matrix = adj_matrix

    def forward(self, data, label):
        # Collect intermediate features for the CIT module and final logits,
        # with the call signature depending on the backbone.
        if isinstance(self.backbone_network, GCNModel):
            intermediate_features = self.backbone_network.conv[0](data['x'], data['edge_index'], None, data["num_nodes"])
            intermediate_features = self.backbone_network.relu(intermediate_features)
            logits = log_softmax(self.backbone_network(data['x'], data['edge_index'], None, data["num_nodes"]))
        elif isinstance(self.backbone_network, GATModel):
            intermediate_features = log_softmax(self.backbone_network(data['x'], data['edge_index'], data["num_nodes"]))
            logits = log_softmax(self.backbone_network(data['x'], data['edge_index'], data["num_nodes"]))
        elif isinstance(self.backbone_network, GCNIIModel):
            intermediate_features = self.backbone_network(data['x'], data['edge_index'], data["edge_weight"], data["num_nodes"])
            logits = log_softmax(self.backbone_network(data['x'], data['edge_index'], data['edge_weight'], data["num_nodes"]))

        if not isinstance(self.backbone_network, GCNModel):
            # Project the intermediate features down to hidden_dim. Note this Linear
            # is re-created on every forward pass, so its weights are not trained.
            intermediate_features = nn.Linear(in_features=intermediate_features.shape[1], out_features=args.hidden_dim, name='linear_1', act=tlx.relu)(intermediate_features)

        train_logits = tlx.gather(logits, data['train_idx'])
        train_y = tlx.gather(data['y'], data['train_idx'])

        # CIT: soft cluster assignments, then MinCut pooling for the auxiliary losses.
        assignment_matrics, _ = self.cit_module.forward(intermediate_features, self.mlp)
        pooled_x, pooled_adj, mc_loss, o_loss = dense_mincut_pool(data['x'], self.adj_matrix, assignment_matrics)
        # Weighted sum of supervised cross-entropy, mincut loss, and orthogonality loss.
        loss = 0.55 * self._loss_fn(train_logits, train_y) + 0.25 * mc_loss + 0.2 * o_loss
        return loss


def test(net, data, args, metrics):
    net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict')
    net.set_eval()

    # Evaluate on a structure-shifted graph whose edges were perturbed at ratio args.ss.
    adj_add = sp.load_npz(f"./datasets/{args.dataset}_add_{args.ss}.npz")
    adj_add = dense_to_sparse(adj_add)

    edge_weight = tlx.convert_to_tensor(calc_gcn_norm(adj_add, data["num_nodes"]))

    if isinstance(net, GATModel):
        logits = log_softmax(net(data['x'], adj_add, data["num_nodes"]))
    elif isinstance(net, GCNModel):
        logits = log_softmax(net(data['x'], adj_add, None, data["num_nodes"]))
    elif isinstance(net, GCNIIModel):
        logits = log_softmax(net(data['x'], adj_add, edge_weight, data['num_nodes']))

    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])

    acc_test = calculate_acc(test_logits, test_y, metrics)
    f1 = F1score(test_logits, test_y)

    print('Test set results:',
          'acc_test: {:.4f}'.format(acc_test.item()),
          'f1_score: {:.4f}'.format(f1.item()))


def main(args):
    if args.dataset.lower() not in ['cora', 'pubmed', 'citeseer']:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    dataset = Planetoid(args.dataset_path, args.dataset)
    graph = dataset[0]
    edge_index, _ = add_self_loops(graph.edge_index, num_nodes=graph.num_nodes)
    edge_weight = tlx.convert_to_tensor(calc_gcn_norm(edge_index, graph.num_nodes))

    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)

    net = get_model(args, dataset)

    loss = tlx.losses.softmax_cross_entropy_with_logits
    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    cit_module = CITModule(clusters=args.clusters, p=args.p)
    mlp = AssignmentMatricsMLP(input_dim=args.hidden_dim, num_clusters=args.clusters, activation='relu')
    adj_matrix = edge_index_to_csr_matrix(edge_index, graph.num_nodes)
    adj_matrix = tlx.convert_to_tensor(adj_matrix.toarray(), dtype=tlx.float32)

    loss_func = SemiSpvzLoss(net, loss, cit_module, mlp, adj_matrix)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "edge_index": edge_index,
        "edge_weight": edge_weight,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "num_nodes": graph.num_nodes,
    }

    best_val_acc = 0
    for epoch in range(args.epochtimes):
        t = time.time()
        net.set_train()
        train_loss = train_one_step(data, graph.y)
        net.set_eval()

        if isinstance(net, GATModel):
            logits = log_softmax(net(data['x'], data['edge_index'], data["num_nodes"]))
        elif isinstance(net, GCNIIModel):
            logits = log_softmax(net(data['x'], data['edge_index'], data["edge_weight"], data["num_nodes"]))
        elif isinstance(net, GCNModel):
            logits = log_softmax(net(data['x'], data['edge_index'], None, data["num_nodes"]))

        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        acc_val = calculate_acc(val_logits, val_y, metrics)

        print('Epoch: {:04d}'.format(epoch + 1),
              'train_loss: {:.4f}'.format(train_loss.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))

        if acc_val > best_val_acc:
            best_val_acc = acc_val
            net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict')

    test(net, data, args, metrics)


if __name__ == '__main__':
    # Parse command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("--gnn", type=str, default="gcn")
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--dataset", type=str, default="cora", help="dataset")
    parser.add_argument("--droprate", type=float, default=0.4)
    parser.add_argument("--p", type=float, default=0.2)
    parser.add_argument("--epochtimes", type=int, default=400)
    parser.add_argument("--clusters", type=int, default=100)
    parser.add_argument("--hidden_dim", type=int, default=8)
    parser.add_argument("--gpu", type=int, default=-1)
    parser.add_argument("--dataset_path", type=str, default=r'')
    parser.add_argument("--ss", type=float, default=0.5, help="structure shift")
    parser.add_argument("--l2_coef", type=float, default=5e-4)
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")

    args = parser.parse_args()
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
```
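The custom loss above mixes the supervised cross-entropy with the two auxiliary terms returned by `dense_mincut_pool` (weights 0.55 / 0.25 / 0.2). That helper lives in the example's `utils.py`, which the commit message lists but which is not shown in this diff. Assuming it follows the standard MinCutPool objective of Bianchi et al. (2020), a minimal NumPy sketch of what it computes could look like this; it is an illustration under that assumption, not the repository's implementation:

```python
import numpy as np

def dense_mincut_pool_sketch(x, adj, s):
    """x: (N, F) node features, adj: (N, N) dense adjacency, s: (N, K) assignment logits."""
    s = np.exp(s - s.max(axis=-1, keepdims=True))
    s = s / s.sum(axis=-1, keepdims=True)        # softmax -> soft cluster assignments
    pooled_x = s.T @ x                           # (K, F) cluster-level features
    pooled_adj = s.T @ adj @ s                   # (K, K) cluster-level adjacency
    degree = np.diag(adj.sum(axis=1))            # degree matrix D
    # Cut loss: -Tr(S^T A S) / Tr(S^T D S); minimized when clusters are densely
    # connected internally relative to their total degree.
    mincut_loss = -np.trace(pooled_adj) / np.trace(s.T @ degree @ s)
    # Orthogonality loss: || S^T S / ||S^T S||_F - I_K / sqrt(K) ||_F; pushes the
    # assignments toward balanced, near-one-hot clusters.
    ss = s.T @ s
    k = s.shape[1]
    ortho_loss = np.linalg.norm(ss / np.linalg.norm(ss) - np.eye(k) / np.sqrt(k))
    return pooled_x, pooled_adj, mincut_loss, ortho_loss
```

In the trainer, the assignment matrix comes from `AssignmentMatricsMLP` over K = `--clusters` = 100 soft clusters, and the two auxiliary losses enter the objective with weights 0.25 and 0.2.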
6 binary files changed (contents not shown): 37.3 KB, 41.9 KB, 38.7 KB, 43.8 KB, 377 KB, 431 KB

examples/citgnn/readme.md

Lines changed: 53 additions & 0 deletions
# Learning Invariant Representations of Graph Neural Networks via Cluster Generalization (CITGNN)

- Paper link: [https://arxiv.org/pdf/2403.03599](https://arxiv.org/pdf/2403.03599)
- Author's code repo: https://github.com/BUPT-GAMMA/CITGNN
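In brief, the mechanism this example trains: CIT (Cluster Information Transfer) clusters node representations during training and, with probability p (`--p` in `citgnn_trainer.py`), transfers a node's cluster information to that of another cluster, so the backbone GNN learns representations that do not depend on spurious cluster structure. The snippet below is a deliberately hedged illustration of one plausible transfer step (a mean/variance statistics swap); the actual `CITModule` in `utils.py` and in the author's repo may implement the transfer differently.

```python
import numpy as np

def transfer_cluster_info(h, assign, p=0.2, rng=None):
    """Hypothetical cluster-statistics transfer.
    h: (N, F) node representations; assign: (N,) hard cluster ids."""
    rng = rng or np.random.default_rng()
    clusters = np.unique(assign)
    mu = {c: h[assign == c].mean(axis=0) for c in clusters}        # cluster means
    sd = {c: h[assign == c].std(axis=0) + 1e-6 for c in clusters}  # cluster stds
    h_new = h.copy()
    for i in range(h.shape[0]):
        if rng.random() < p:                 # transfer this node with probability p
            c_src = assign[i]
            c_dst = rng.choice(clusters)     # pick a target cluster
            # Re-express the node with the target cluster's statistics.
            h_new[i] = (h[i] - mu[c_src]) / sd[c_src] * sd[c_dst] + mu[c_dst]
    return h_new
```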
# Dataset Statistics

| Dataset  | # Nodes | # Edges | # Classes |
| -------- | ------- | ------- | --------- |
| Cora     | 2,708   | 10,556  | 7         |
| Citeseer | 3,327   | 9,228   | 6         |
| Pubmed   | 19,717  | 88,651  | 3         |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid).

# Results

- Available datasets: "cora", "citeseer", "pubmed"
- Available GNN backbones (`--gnn`): "gcn", "gat", "gcnii"

```bash
# available datasets: "cora", "citeseer", "pubmed"
python citgnn_trainer.py --gnn gcn --dataset cora --lr 0.005 --l2_coef 0.01 --droprate 0.8
python citgnn_trainer.py --gnn gcn --dataset citeseer --lr 0.01 --l2_coef 0.01 --droprate 0.7
python citgnn_trainer.py --gnn gcn --dataset pubmed --lr 0.01 --l2_coef 0.002 --droprate 0.5

python citgnn_trainer.py --gnn gat --dataset cora --lr 0.005 --l2_coef 0.005 --droprate 0.5
python citgnn_trainer.py --gnn gat --dataset citeseer --lr 0.01 --l2_coef 0.005 --droprate 0.5
python citgnn_trainer.py --gnn gat --dataset pubmed --lr 0.01 --l2_coef 0.001 --droprate 0.2

python citgnn_trainer.py --gnn gcnii --dataset cora --lr 0.01 --l2_coef 0.001 --droprate 0.3
python citgnn_trainer.py --gnn gcnii --dataset citeseer --lr 0.01 --l2_coef 0.001 --droprate 0.4
python citgnn_trainer.py --gnn gcnii --dataset pubmed --lr 0.01 --l2_coef 0.001 --droprate 0.6
```

Results under the ADD-0.5 setting (test-time structure shift with added edges, i.e. `--ss 0.5`):

Paper:

| Method    | Cora Acc   | Cora Macro-f1 | Citeseer Acc | Citeseer Macro-f1 | Pubmed Acc | Pubmed Macro-f1 |
| --------- | ---------- | ------------- | ------------ | ----------------- | ---------- | --------------- |
| CIT-GCN   | 76.98±0.49 | 75.88±0.44    | 67.65±0.44   | 64.42±0.10        | 73.76±0.40 | 72.94±0.30      |
| CIT-GAT   | 77.23±0.42 | 76.26±0.28    | 66.33±0.24   | 63.07±0.37        | 72.50±0.74 | 71.57±0.82      |
| CIT-GCNII | 78.28±0.88 | 75.82±0.73    | 66.12±0.97   | 63.17±0.85        | 75.95±0.63 | 75.47±0.76      |

Ours:

| Method    | Cora Acc   | Cora Macro-f1 | Citeseer Acc | Citeseer Macro-f1 | Pubmed Acc | Pubmed Macro-f1 |
| --------- | ---------- | ------------- | ------------ | ----------------- | ---------- | --------------- |
| CIT-GCN   | 77.52±1.08 | 76.49±0.49    | 65.78±0.91   | 62.60±1.17        | 72.42±0.25 | 71.65±0.44      |
| CIT-GAT   | 75.84±0.56 | 74.81±0.66    | 63.41±1.28   | 59.98±1.42        | 71.80±0.64 | 70.78±0.69      |
| CIT-GCNII | 80.30±1.06 | 78.44±1.18    | 65.94±1.04   | 62.67±0.80        | 76.27±0.49 | 75.30±0.67      |
