
Commit 651f761

Merge pull request #61 from CarloLucibello/cl/link
add link prediction example
2 parents 00aade3 + 13468e3 commit 651f761

File tree: 4 files changed (+147, −2 lines)

examples/link_prediction_pubmed.jl

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
# An example of link prediction using negative and positive samples.
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py

using Flux
# Link prediction task
# https://arxiv.org/pdf/2102.12557.pdf

using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using MLDatasets: PubMed, Cora
using Statistics, Random, LinearAlgebra
using CUDA
# using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3          # learning rate
    epochs = 200      # number of epochs
    seed = 17         # set seed > 0 for reproducibility
    usecuda = false   # if true use cuda (if available)
    nhidden = 64      # dimension of hidden features
    infotime = 10     # report every `infotime` epochs
end

# The DotPredictor scores a candidate edge by the dot product
# of the embeddings of its two endpoint nodes.
struct DotPredictor end

function (::DotPredictor)(g, x)
    z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
    return vec(z)
end

using ChainRulesCore

function train(; kws...)
    args = Args(; kws...)

    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ### LOAD DATA
    data = Cora.dataset()
    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device

    #### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
    s, t = edge_index(g)
    eids = randperm(g.num_edges)
    test_size = round(Int, g.num_edges * 0.1)

    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
    test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)

    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
    train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)

    test_neg_g = negative_sample(g, num_neg_edges=test_size)

    ### DEFINE MODEL #########
    nin, nhidden = size(X, 1), args.nhidden

    model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
                               GCNConv(nhidden => nhidden)),
                      train_pos_g) |> device

    pred = DotPredictor()

    ps = Flux.params(model)
    opt = ADAM(args.η)

    ### LOSS FUNCTION ############

    function loss(pos_g, neg_g = nothing)
        h = model(X)
        if neg_g === nothing
            # we sample a negative graph at each training step
            neg_g = negative_sample(pos_g)
        end
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
        return logitbinarycrossentropy(scores, labels)
    end

    # function accuracy(pos_g, neg_g)
    #     h = model(X)
    #     pos_score = pred(pos_g, h)
    #     neg_score = pred(neg_g, h)
    #     scores = [pos_score; neg_score]
    #     labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
    #     return mean((scores .> 0) .== (labels .== 1))
    # end

    ### LOGGING FUNCTION
    function report(epoch)
        train_loss = loss(train_pos_g)
        test_loss = loss(test_pos_g, test_neg_g)
        println("Epoch: $epoch Train: $(train_loss) Test: $(test_loss)")
    end

    ### TRAINING
    report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(() -> loss(train_pos_g), ps)
        Flux.Optimise.update!(opt, ps, gs)
        epoch % args.infotime == 0 && report(epoch)
    end
end

# train()
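
For reference, a possible way to run the example from the Julia REPL (a sketch: the include path is an assumption relative to the repository root, and the keyword overrides assume `train` constructs its settings with `Args(; kws...)`):

    include("examples/link_prediction_pubmed.jl")

    train()                             # defaults from `Args`
    train(epochs = 100, nhidden = 32)   # or override selected hyperparameters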

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
     graph_indicator

 include("transform.jl")
-export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph
+export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph,
+    negative_sample

 include("generate.jl")
 export rand_graph

src/GNNGraphs/transform.jl

Lines changed: 21 additions & 0 deletions
@@ -324,5 +324,26 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
 end
 end

+
+"""
+    negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
+
+Return a graph containing random negative edges (i.e. non-edges) from graph `g`.
+"""
+function negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
+    adj = adjacency_matrix(g)
+    adj_neg = 1 .- adj - I
+    neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)
+    neg_eids = randperm(length(neg_s))[1:num_neg_edges]
+    neg_s, neg_t = neg_s[neg_eids], neg_t[neg_eids]
+    return GNNGraph(neg_s, neg_t, num_nodes=g.num_nodes)
+end
+
+# """
+# Transform a vector of Cartesian indexes into a tuple of vectors containing integers.
+# """
+ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
+
+@non_differentiable negative_sample(x...)
 @non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
 @non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
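
For illustration, a minimal sketch of how the new `negative_sample` function might be used; the input graph here is just a random example built with the exported `rand_graph`:

    using GraphNeuralNetworks

    g = rand_graph(10, 30)                         # random graph with 10 nodes and 30 edges
    neg_g = negative_sample(g, num_neg_edges=15)   # graph made of 15 random non-edges of g
    @assert neg_g.num_nodes == g.num_nodes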

src/msgpass.jl

Lines changed: 2 additions & 1 deletion
@@ -77,7 +77,8 @@ end
 ## APPLY EDGES

 """
-    apply_edges(f, xi, xj, e)
+    apply_edges(f, g, xi, xj, e)
+    apply_edges(f, g; [xi, xj, e])

 Returns the message from node `j` to node `i`.
 In the message-passing scheme, the incoming messages
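
To illustrate the keyword form documented above, a small sketch mirroring the dot-product edge scoring used in the new example (the graph and features are arbitrary placeholders):

    using GraphNeuralNetworks

    g = rand_graph(4, 6)        # random graph: 4 nodes, 6 edges
    x = rand(Float32, 8, 4)     # 8 features per node, one column per node
    # one value per edge: dot product of the source and target node features
    scores = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)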

0 commit comments