Commit 4c96da7

add link prediction example
1 parent 00aade3 commit 4c96da7

File tree

2 files changed: +133 −1 lines changed

examples/link_prediction_cora.jl

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# An example of link prediction using negative and positive samples.
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py

using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using GraphNeuralNetworks: ones_like, zeros_like
using MLDatasets: Cora
using Statistics, Random, LinearAlgebra
using CUDA
using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)

"""
Transform a vector of cartesian indices into a tuple of `dims` integer vectors, one per dimension.
"""
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
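# e.g. ci2t([CartesianIndex(1, 2), CartesianIndex(3, 4)], 2) == ([1, 3], [2, 4])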

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3            # learning rate
    epochs = 200        # number of epochs
    seed = 17           # set seed > 0 for reproducibility
    usecuda = false     # if true use cuda (if available)
    nhidden = 128       # dimension of hidden features
    infotime = 10       # report every `infotime` epochs
end

struct DotPredictor end

function (::DotPredictor)(g, x)
    # score each edge by the dot product of its endpoints' features
    z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
    return vec(z)
end
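
# A quick sanity check on a toy graph (hypothetical names `gex`, `h`):
# for a single edge 1 -> 2, the predictor returns [dot(h[:, 1], h[:, 2])].
#   gex = GNNGraph(([1], [2]), num_nodes=2)
#   h = rand(Float32, 4, 2)
#   @assert DotPredictor()(gex, h) ≈ [dot(h[:, 1], h[:, 2])]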

function train(; kws...)
    args = Args(; kws...)

    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ### LOAD DATA
    data = Cora.dataset()
    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device

    ### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
    # Split the edge set for training and testing
    s, t = edge_index(g)
    eids = randperm(g.num_edges)
    test_size = round(Int, g.num_edges * 0.1)
    train_size = g.num_edges - test_size
    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]

    # Find all negative edges (pairs of distinct nodes not connected by an edge)
    adj = adjacency_matrix(g)
    adj_neg = 1 .- adj - I
    neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)
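    # For intuition: on a 3-node graph with the single edge 1 -> 2,
    # `1 .- adj - I` zeroes the diagonal and the existing edge, leaving
    # (1,3), (2,1), (2,3), (3,1), (3,2) as candidate negative pairs.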

    # Sample as many negative edges as there are positive ones and split them
    neg_eids = randperm(length(neg_s))[1:g.num_edges]
    test_neg_s, test_neg_t = neg_s[neg_eids[1:test_size]], neg_t[neg_eids[1:test_size]]
    train_neg_s, train_neg_t = neg_s[neg_eids[test_size+1:end]], neg_t[neg_eids[test_size+1:end]]

    train_pos_g = GNNGraph((train_pos_s, train_pos_t), num_nodes=g.num_nodes)
    train_neg_g = GNNGraph((train_neg_s, train_neg_t), num_nodes=g.num_nodes)

    test_pos_g = GNNGraph((test_pos_s, test_pos_t), num_nodes=g.num_nodes)
    test_neg_g = GNNGraph((test_neg_s, test_neg_t), num_nodes=g.num_nodes)

    @show train_pos_g test_pos_g train_neg_g test_neg_g

    ### DEFINE MODEL
    nin, nhidden = size(X, 1), args.nhidden

    model = GNNChain(GCNConv(nin => nhidden, relu),
                     GCNConv(nhidden => nhidden)) |> device

    pred = DotPredictor()

    ps = Flux.params(model)
    opt = ADAM(args.η)

    ### LOSS FUNCTION

    function loss(pos_g, neg_g)
        # node embeddings are always computed on the training graph
        h = model(train_pos_g, X)
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        labels = [ones_like(pos_score); zeros_like(neg_score)]
        return logitbinarycrossentropy(scores, labels)
    end
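    # `logitbinarycrossentropy` applies the sigmoid internally, so the raw
    # dot-product scores are passed in directly, with label 1 for positive
    # edges and 0 for negative ones.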

    function accuracy(pos_g, neg_g)
        h = model(train_pos_g, X)
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        # a score above zero predicts an edge; average the accuracy over
        # positive and negative samples
        return 0.5 * mean(pos_score .>= 0) + 0.5 * mean(neg_score .< 0)
    end

    ### LOGGING FUNCTION
    function report(epoch)
        train_loss = loss(train_pos_g, train_neg_g)
        test_loss = loss(test_pos_g, test_neg_g)
        train_acc = accuracy(train_pos_g, train_neg_g)
        test_acc = accuracy(test_pos_g, test_neg_g)
        println("Epoch: $epoch  Train: loss = $train_loss, acc = $train_acc  Test: loss = $test_loss, acc = $test_acc")
    end

    ### TRAINING
    report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(() -> loss(train_pos_g, train_neg_g), ps)
        Flux.Optimise.update!(opt, ps, gs)
        epoch % args.infotime == 0 && report(epoch)
    end
end

# train()
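# e.g. train(usecuda = true, epochs = 100) trains on the GPU for 100 epochs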

src/msgpass.jl

Lines changed: 2 additions & 1 deletion
@@ -77,7 +77,8 @@ end
 ## APPLY EDGES

 """
-    apply_edges(f, xi, xj, e)
+    apply_edges(f, g, xi, xj, e)
+    apply_edges(f, g; [xi, xj, e])

 Returns the message from node `j` to node `i`.
 In the message-passing scheme, the incoming messages
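
The new keyword form is the one exercised by the example above; a minimal illustration (with `g` a GNNGraph and `x` a matrix holding one feature column per node):

    m = apply_edges((xi, xj, e) -> xi .+ xj, g, xi=x, xj=x)

Here `m` holds one column per edge: the elementwise sum of the features of its two endpoints.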
