# An example of link prediction using negative and positive samples.
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py

using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using GraphNeuralNetworks: ones_like, zeros_like
using MLDatasets: Cora
using Statistics, Random, LinearAlgebra
using CUDA
using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)

"""
Transform a vector of `CartesianIndex`es into a tuple of integer vectors, one per dimension.
"""
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
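# e.g. ci2t([CartesianIndex(1, 2), CartesianIndex(3, 4)], 2) == ([1, 3], [2, 4]),
# which is the (source, target) form used for the edge lists below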

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3            # learning rate
    epochs = 200        # number of epochs
    seed = 17           # set seed > 0 for reproducibility
    usecuda = false     # if true use cuda (if available)
    nhidden = 128       # dimension of hidden features
    infotime = 10       # report every `infotime` epochs
end

struct DotPredictor end

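# Score each edge of `g` with the dot product of the embeddings of its two endpoints.
# `apply_edges` evaluates the closure on the (source, target) feature pair of every edge,
# returning one score per edge.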
function (::DotPredictor)(g, x)
    z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
    return vec(z)
end

function train(; kws...)
    args = Args(; kws...)

    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ### LOAD DATA
    data = Cora.dataset()
    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device
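    # `X` holds one column of input features per node; `size(X, 1)` is used below as
    # the input dimension of the first GCN layer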

    #### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
    # Split edge set for training and testing
    s, t = edge_index(g)
    eids = randperm(g.num_edges)
    test_size = round(Int, g.num_edges * 0.1)
    train_size = g.num_edges - test_size
    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]

    # Find all negative edges and split them for training and testing
    adj = adjacency_matrix(g)
    adj_neg = 1 .- adj - I
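    # `adj_neg` is 1 exactly where the graph has no edge and no self-loop, so its
    # nonzero entries enumerate every candidate negative (non-existent) edge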
    neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)

    neg_eids = randperm(length(neg_s))[1:g.num_edges]
    test_neg_s, test_neg_t = neg_s[neg_eids[1:test_size]], neg_t[neg_eids[1:test_size]]
    train_neg_s, train_neg_t = neg_s[neg_eids[test_size+1:end]], neg_t[neg_eids[test_size+1:end]]

    train_pos_g = GNNGraph((train_pos_s, train_pos_t), num_nodes=g.num_nodes)
    train_neg_g = GNNGraph((train_neg_s, train_neg_t), num_nodes=g.num_nodes)

    test_pos_g = GNNGraph((test_pos_s, test_pos_t), num_nodes=g.num_nodes)
    test_neg_g = GNNGraph((test_neg_s, test_neg_t), num_nodes=g.num_nodes)

    @show train_pos_g test_pos_g train_neg_g test_neg_g

    ### DEFINE MODEL
    nin, nhidden = size(X, 1), args.nhidden

    model = GNNChain(GCNConv(nin => nhidden, relu),
                     GCNConv(nhidden => nhidden)) |> device
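    # the two GCN layers produce an `nhidden`-dimensional embedding for every node;
    # candidate edges are then scored by the dot-product predictor defined above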

    pred = DotPredictor()

    ps = Flux.params(model)
    opt = ADAM(args.η)

    ### LOSS FUNCTION

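    # Node embeddings are always computed by message passing on the training graph
    # `train_pos_g`; `pos_g` and `neg_g` only supply the candidate edges to score.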
    function loss(pos_g, neg_g)
        h = model(train_pos_g, X)
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        labels = [ones_like(pos_score); zeros_like(neg_score)]
        return logitbinarycrossentropy(scores, labels)
    end

    function accuracy(pos_g, neg_g)
        h = model(train_pos_g, X)
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        labels = [ones_like(pos_score); zeros_like(neg_score)]
        # a logit above 0 corresponds to a predicted edge probability above 0.5
        return mean((scores .> 0) .== (labels .> 0.5))
    end

    ### LOGGING FUNCTION
    function report(epoch)
        train_loss = loss(train_pos_g, train_neg_g)
        test_loss = loss(test_pos_g, test_neg_g)
        train_acc = accuracy(train_pos_g, train_neg_g)
        test_acc = accuracy(test_pos_g, test_neg_g)
        println("Epoch: $epoch  Train: loss=$train_loss acc=$train_acc  Test: loss=$test_loss acc=$test_acc")
    end

    ### TRAINING
    report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(() -> loss(train_pos_g, train_neg_g), ps)
        Flux.Optimise.update!(opt, ps, gs)
        epoch % args.infotime == 0 && report(epoch)
    end
end

# train()
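# To run the example, call `train()`; keyword arguments matching the fields of `Args`
# can be passed, e.g. `train(usecuda=true, epochs=100)`.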