Skip to content

Commit 81fb4d1

Browse files
implement negative sampling
1 parent 53f79cd commit 81fb4d1

File tree

2 files changed

+37
-45
lines changed

2 files changed

+37
-45
lines changed

examples/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4+
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
35
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
46
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
57
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
68
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
7-
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
89
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
910
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

examples/link_prediction_cora.jl renamed to examples/link_prediction_pubmed.jl

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,18 @@
22
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py
33

44
using Flux
5+
# Link prediction task
6+
# https://arxiv.org/pdf/2102.12557.pdf
7+
58
using Flux: onecold, onehotbatch
69
using Flux.Losses: logitbinarycrossentropy
710
using GraphNeuralNetworks
8-
using GraphNeuralNetworks: ones_like, zeros_like
9-
using MLDatasets: Cora
11+
using MLDatasets: PubMed, Cora
1012
using Statistics, Random, LinearAlgebra
1113
using CUDA
12-
using MLJBase: AreaUnderCurve
14+
# using MLJBase: AreaUnderCurve
1315
CUDA.allowscalar(false)
1416

15-
"""
16-
Transform vector of cartesian indexes into a tuple of vectors containing integers.
17-
"""
18-
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
19-
2017
# arguments for the `train` function
2118
Base.@kwdef mutable struct Args
2219
η = 1f-3 # learning rate
@@ -34,6 +31,8 @@ function (::DotPredictor)(g, x)
3431
return vec(z)
3532
end
3633

34+
using ChainRulesCore
35+
3736
function train(; kws...)
3837
# args = Args(; kws...)
3938
args = Args()
@@ -54,75 +53,67 @@ function train(; kws...)
5453
g = GNNGraph(data.adjacency_list) |> device
5554
X = data.node_features |> device
5655

56+
5757
#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
58-
# Split edge set for training and testing
5958
s, t = edge_index(g)
6059
eids = randperm(g.num_edges)
6160
test_size = round(Int, g.num_edges * 0.1)
62-
train_size = g.num_edges - test_size
61+
6362
test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
64-
train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
65-
66-
# Find all negative edges and split them for training and testing
67-
adj = adjacency_matrix(g)
68-
adj_neg = 1 .- adj - I
69-
neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)
70-
71-
neg_eids = randperm(length(neg_s))[1:g.num_edges]
72-
test_neg_s, test_neg_t = neg_s[neg_eids[1:test_size]], neg_t[neg_eids[1:test_size]]
73-
train_neg_s, train_neg_t = neg_s[neg_eids[test_size+1:end]], neg_t[neg_eids[test_size+1:end]]
74-
# train_neg_s, train_neg_t = neg_s[neg_eids[train_size+1:end]], neg_t[neg_eids[train_size+1:end]]
63+
test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)
7564

76-
train_pos_g = GNNGraph((train_pos_s, train_pos_t), num_nodes=g.num_nodes)
77-
train_neg_g = GNNGraph((train_neg_s, train_neg_t), num_nodes=g.num_nodes)
65+
train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
66+
train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)
7867

79-
test_pos_g = GNNGraph((test_pos_s, test_pos_t), num_nodes=g.num_nodes)
80-
test_neg_g = GNNGraph((test_neg_s, test_neg_t), num_nodes=g.num_nodes)
68+
test_neg_g = negative_sample(g, num_neg_edges=test_size)
8169

82-
@show train_pos_g test_pos_g train_neg_g test_neg_g
83-
84-
### DEFINE MODEL
70+
### DEFINE MODEL #########
8571
nin, nhidden = size(X,1), args.nhidden
8672

87-
model = GNNChain(GCNConv(nin => nhidden, relu),
88-
GCNConv(nhidden => nhidden)) |> device
73+
model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
74+
GCNConv(nhidden => nhidden)),
75+
train_pos_g) |> device
8976

9077
pred = DotPredictor()
9178

9279
ps = Flux.params(model)
9380
opt = ADAM(args.η)
9481

95-
### LOSS FUNCTION
82+
### LOSS FUNCTION ############
9683

97-
function loss(pos_g, neg_g)
98-
h = model(train_pos_g, X)
84+
function loss(pos_g, neg_g = nothing)
85+
h = model(X)
86+
if neg_g === nothing
87+
# we sample a negative graph at each training step
88+
neg_g = negative_sample(pos_g)
89+
end
9990
pos_score = pred(pos_g, h)
10091
neg_score = pred(neg_g, h)
10192
scores = [pos_score; neg_score]
102-
labels = [ones_like(pos_score); zeros_like(neg_score)]
93+
labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
10394
return logitbinarycrossentropy(scores, labels)
10495
end
10596

106-
function accuracy(pos_g, neg_g)
107-
h = model(train_pos_g, X)
108-
pos_score = pred(pos_g, h)
109-
neg_score = pred(neg_g, h)
110-
scores = [pos_score; neg_score]
111-
labels = [ones_like(pos_score); zeros_like(neg_score)]
112-
return logitbinarycrossentropy(scores, labels)
113-
end
97+
# function accuracy(pos_g, neg_g)
98+
# h = model(train_pos_g, X)
99+
# pos_score = pred(pos_g, h)
100+
# neg_score = pred(neg_g, h)
101+
# scores = [pos_score; neg_score]
102+
# labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
103+
# return logitbinarycrossentropy(scores, labels)
104+
# end
114105

115106
### LOGGING FUNCTION
116107
function report(epoch)
117-
train_loss = loss(train_pos_g, train_neg_g)
108+
train_loss = loss(train_pos_g)
118109
test_loss = loss(test_pos_g, test_neg_g)
119110
println("Epoch: $epoch Train: $(train_loss) Test: $(test_loss)")
120111
end
121112

122113
### TRAINING
123114
report(0)
124115
for epoch in 1:args.epochs
125-
gs = Flux.gradient(() -> loss(train_pos_g, train_neg_g), ps)
116+
gs = Flux.gradient(() -> loss(train_pos_g), ps)
126117
Flux.Optimise.update!(opt, ps, gs)
127118
epoch % args.infotime == 0 && report(epoch)
128119
end

0 commit comments

Comments
 (0)