# An example of link prediction using negative and positive samples.
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py
#
# Link prediction task:
# https://arxiv.org/pdf/2102.12557.pdf

using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using MLDatasets: PubMed, Cora
using Statistics, Random, LinearAlgebra
using CUDA
using ChainRulesCore
# using MLJBase: AreaUnderCurve

# forbid slow scalar indexing of GPU arrays
CUDA.allowscalar(false)

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3           # learning rate
    epochs = 200       # number of epochs
    seed = 17          # set seed > 0 for reproducibility
    usecuda = false    # if true use cuda (if available)
    nhidden = 64       # dimension of hidden features
    infotime = 10      # report every `infotime` epochs
end

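# Score a candidate edge by the dot product of its endpoints' embeddings.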
27
+ struct DotPredictor end
28
+
29
+ function (:: DotPredictor )(g, x)
30
+ z = apply_edges ((xi, xj, e) -> sum (xi .* xj, dims= 1 ), g, xi= x, xj= x)
31
+ return vec (z)
32
+ end
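# With (s, t) = edge_index(g), the k-th returned score is dot(x[:, s[k]], x[:, t[k]]).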

function train(; kws...)
    args = Args(; kws...)

    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ### LOAD DATA
    data = Cora.dataset()
    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device

    ### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
    s, t = edge_index(g)
    eids = randperm(g.num_edges)
    test_size = round(Int, g.num_edges * 0.1)

    # 10% of the positive edges are held out for testing
    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
    test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)

    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
    train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)

    # negative test edges are sampled once, so test metrics are comparable across epochs
    test_neg_g = negative_sample(g, num_neg_edges=test_size)

    ### DEFINE MODEL
    nin, nhidden = size(X, 1), args.nhidden

    # WithGraph ties train_pos_g to the model, so that `model(X)`
    # runs the GNN on the training graph
    model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
                               GCNConv(nhidden => nhidden)),
                      train_pos_g) |> device

    pred = DotPredictor()

    ps = Flux.params(model)
    opt = ADAM(args.η)

    ### LOSS FUNCTION

    function loss(pos_g, neg_g=nothing)
        h = model(X)
        if neg_g === nothing
            # we sample a negative graph at each training step
            neg_g = negative_sample(pos_g)
        end
        # `pred` returns raw logits; logitbinarycrossentropy applies the sigmoid internally
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        # binary targets: 1 for real edges, 0 for sampled non-edges
        labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
        return logitbinarycrossentropy(scores, labels)
    end

    # function accuracy(pos_g, neg_g)
    #     h = model(X)
    #     pos_score = pred(pos_g, h)
    #     neg_score = pred(neg_g, h)
    #     scores = [pos_score; neg_score]
    #     labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
    #     return mean((Flux.σ.(scores) .> 0.5) .== labels)
    # end
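
    # A possible AUC instead of MLJBase's AreaUnderCurve, via the rank-sum
    # statistic (a sketch; `rank_auc` is not part of the original code):
    # function rank_auc(pos_g, neg_g)
    #     h = model(X)
    #     pos_score, neg_score = pred(pos_g, h), pred(neg_g, h)
    #     ranks = invperm(sortperm([pos_score; neg_score]))
    #     np, nn = length(pos_score), length(neg_score)
    #     return (sum(ranks[1:np]) - np * (np + 1) / 2) / (np * nn)
    # end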

    ### LOGGING FUNCTION
    function report(epoch)
        train_loss = loss(train_pos_g)
        test_loss = loss(test_pos_g, test_neg_g)
        println("Epoch: $epoch   Train: $(train_loss)   Test: $(test_loss)")
    end

    ### TRAINING
    report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(() -> loss(train_pos_g), ps)
        Flux.Optimise.update!(opt, ps, gs)
        epoch % args.infotime == 0 && report(epoch)
    end
end

# train()
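# e.g., to run on GPU for 100 epochs (keyword names match the Args fields above):
# train(usecuda=true, epochs=100)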