# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py
using Flux
+ # Link prediction task
+ # https://arxiv.org/pdf/2102.12557.pdf
+
using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
- using GraphNeuralNetworks: ones_like, zeros_like
- using MLDatasets: Cora
+ using MLDatasets: PubMed, Cora
using Statistics, Random, LinearAlgebra
using CUDA
- using MLJBase: AreaUnderCurve
+ # using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)

- """
- Transform vector of cartesian indexes into a tuple of vectors containing integers.
- """
- ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
-
# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
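Only the first field of `Args` is visible in this hunk; the rest of the struct is unchanged and
so outside the diff, but it must define at least the fields the script uses later. A plausible
shape, with every default beyond `η` an assumption:

Base.@kwdef mutable struct Args
    η = 1f-3          # learning rate
    epochs = 200      # number of training epochs (assumed)
    nhidden = 128     # width of the hidden GCN layers (assumed)
    infotime = 10     # report train/test loss every `infotime` epochs (assumed)
    usecuda = true    # try to run on the GPU via CUDA.jl (assumed)
end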
@@ -34,6 +31,8 @@ function (::DotPredictor)(g, x)
    return vec(z)
end

+ using ChainRulesCore
+
function train(; kws...)
    # args = Args(; kws...)
    args = Args()
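The hunk header shows `(::DotPredictor)(g, x)` ending in `return vec(z)`: the predictor scores
every edge of `g` by the dot product of its endpoint embeddings. A minimal sketch of such a
definition, assuming the `apply_edges` API from GraphNeuralNetworks.jl:

struct DotPredictor end

function (::DotPredictor)(g, x)
    # for each edge (i, j) in g the score is dot(x[:, i], x[:, j]);
    # apply_edges evaluates the kernel on every edge at once
    z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
    return vec(z)
end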
@@ -54,75 +53,67 @@ function train(; kws...)
    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device

+
    #### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
-    # Split edge set for training and testing
    s, t = edge_index(g)
    eids = randperm(g.num_edges)
    test_size = round(Int, g.num_edges * 0.1)
-    train_size = g.num_edges - test_size
+
    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
-    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
-
-    # Find all negative edges and split them for training and testing
-    adj = adjacency_matrix(g)
-    adj_neg = 1 .- adj - I
-    neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)
-
-    neg_eids = randperm(length(neg_s))[1:g.num_edges]
-    test_neg_s, test_neg_t = neg_s[neg_eids[1:test_size]], neg_t[neg_eids[1:test_size]]
-    train_neg_s, train_neg_t = neg_s[neg_eids[test_size+1:end]], neg_t[neg_eids[test_size+1:end]]
-    # train_neg_s, train_neg_t = neg_s[neg_eids[train_size+1:end]], neg_t[neg_eids[train_size+1:end]]
+    test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)

-    train_pos_g = GNNGraph((train_pos_s, train_pos_t), num_nodes=g.num_nodes)
-    train_neg_g = GNNGraph((train_neg_s, train_neg_t), num_nodes=g.num_nodes)
+    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
+    train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)

-    test_pos_g = GNNGraph((test_pos_s, test_pos_t), num_nodes=g.num_nodes)
-    test_neg_g = GNNGraph((test_neg_s, test_neg_t), num_nodes=g.num_nodes)
+    test_neg_g = negative_sample(g, num_neg_edges=test_size)
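`negative_sample` (from GraphNeuralNetworks.jl) replaces the hand-rolled search over the dense
non-adjacency matrix that this hunk deletes. As a rough sketch of what it computes, here is a
naive dense equivalent reassembled from the deleted lines (`naive_negative_sample` is
illustrative only, not part of the diff):

# sample k edges that are NOT present in g, returned as a new GNNGraph
function naive_negative_sample(g, k)
    adj = adjacency_matrix(g)
    adj_neg = 1 .- adj - I                    # 1 wherever there is no edge and no self-loop
    cis = findall(adj_neg .> 0)               # CartesianIndex of every candidate non-edge
    cis = cis[randperm(length(cis))[1:k]]     # keep k of them at random
    GNNGraph(map(c -> c[1], cis), map(c -> c[2], cis), num_nodes=g.num_nodes)
end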
-    @show train_pos_g test_pos_g train_neg_g test_neg_g
-
-    ### DEFINE MODEL
+    ### DEFINE MODEL #########
    nin, nhidden = size(X, 1), args.nhidden

-    model = GNNChain(GCNConv(nin => nhidden, relu),
-                     GCNConv(nhidden => nhidden)) |> device
+    model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
+                               GCNConv(nhidden => nhidden)),
+                      train_pos_g) |> device
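`WithGraph` ties the model to a fixed graph (here `train_pos_g`), so the forward pass below can
be written `model(X)` instead of `model(train_pos_g, X)`, and the wrapper can precompute
graph-dependent quantities once. Conceptually it behaves like this hypothetical re-implementation:

# sketch only: what calling a WithGraph-wrapped model amounts to
struct FixedGraphModel{M,G}
    model::M
    g::G
end
(m::FixedGraphModel)(x) = m.model(m.g, x)   # graph is implicit, features are the only input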
    pred = DotPredictor()

    ps = Flux.params(model)
    opt = ADAM(args.η)

-    ### LOSS FUNCTION
+    ### LOSS FUNCTION ############

-    function loss(pos_g, neg_g)
-        h = model(train_pos_g, X)
+    function loss(pos_g, neg_g=nothing)
+        h = model(X)
+        if neg_g === nothing
+            # we sample a negative graph at each training step
+            neg_g = negative_sample(pos_g)
+        end
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
-        labels = [ones_like(pos_score); zeros_like(neg_score)]
+        labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
        return logitbinarycrossentropy(scores, labels)
    end
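The `fill!(similar(score), v)` idiom replaces the deleted `ones_like`/`zeros_like` helpers: it
builds the label vector with the same element type and on the same device (plain Array or
CuArray) as the scores, which keeps the loss GPU-compatible. For example, on the CPU:

# assuming Float32 scores, as the model above produces
pos_score = randn(Float32, 5)
neg_score = randn(Float32, 3)
labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
# -> a Float32 vector: five 1s followed by three 0s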
-    function accuracy(pos_g, neg_g)
-        h = model(train_pos_g, X)
-        pos_score = pred(pos_g, h)
-        neg_score = pred(neg_g, h)
-        scores = [pos_score; neg_score]
-        labels = [ones_like(pos_score); zeros_like(neg_score)]
-        return logitbinarycrossentropy(scores, labels)
-    end
+    # function accuracy(pos_g, neg_g)
+    #     h = model(train_pos_g, X)
+    #     pos_score = pred(pos_g, h)
+    #     neg_score = pred(neg_g, h)
+    #     scores = [pos_score; neg_score]
+    #     labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
+    #     return logitbinarycrossentropy(scores, labels)
+    # end
    ### LOGGING FUNCTION
    function report(epoch)
-        train_loss = loss(train_pos_g, train_neg_g)
+        train_loss = loss(train_pos_g)
        test_loss = loss(test_pos_g, test_neg_g)
        println("Epoch: $epoch   Train: $(train_loss)   Test: $(test_loss)")
    end

    ### TRAINING
    report(0)
    for epoch in 1:args.epochs
-        gs = Flux.gradient(() -> loss(train_pos_g, train_neg_g), ps)
+        gs = Flux.gradient(() -> loss(train_pos_g), ps)
        Flux.Optimise.update!(opt, ps, gs)
        epoch % args.infotime == 0 && report(epoch)
    end
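With the pieces above, running the example comes down to a single call; note that keyword
overrides only take effect once the commented `args = Args(; kws...)` line is restored:

train()                       # reports at epoch 0, then every args.infotime epochs
# train(epochs=500, η=5f-4)   # works only with Args(; kws...) re-enabled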