Skip to content

Commit d5c7a80

Browse files
Merge pull request #70 from CarloLucibello/cl/negative
improvements to link prediction
2 parents 683c8b7 + 51e7894 commit d5c7a80

File tree

14 files changed

+272
-53
lines changed

14 files changed

+272
-53
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ makedocs(;
1313
"Graphs" => "gnngraph.md",
1414
"Message Passing" => "messagepassing.md",
1515
"Model Building" => "models.md",
16+
"Datasets" => "datasets.md",
1617
"API Reference" =>
1718
[
1819
"GNNGraph" => "api/gnngraph.md",

docs/src/datasets.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Datasets
2+
3+
GNN.jl doesn't come with its own datasets, but leverages those available in the julia (and non-julia) ecosytem. In particular, the [examples in the GNN.jl repository](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) make use of the [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) package. There you will find common graph datasets sich as Cora, PubMed, and Citeseer.
4+
Also MLDatasets gives access to the [TUDataset](https://chrsmrrs.github.io/datasets/docs/datasets/) repository and its numerous datasets.

examples/link_prediction_pubmed.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,20 @@ function train(; kws...)
4949
### LOAD DATA
5050
data = Cora.dataset()
5151
# data = PubMed.dataset()
52-
g = GNNGraph(data.adjacency_list) |> device
52+
g = GNNGraph(data.adjacency_list)
53+
@info g
5354
@show is_bidirected(g)
55+
@show has_self_loops(g)
56+
@show has_multi_edges(g)
57+
@show mean(degree(g))
58+
isbidir = is_bidirected(g)
59+
60+
g = g |> device
5461
X = data.node_features |> device
5562

5663
#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
57-
s, t = edge_index(g)
58-
eids = randperm(g.num_edges)
59-
test_size = round(Int, g.num_edges * 0.1)
60-
61-
test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
62-
test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)
63-
64-
train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
65-
train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)
66-
67-
test_neg_g = negative_sample(g, num_neg_edges=test_size)
68-
64+
train_pos_g, test_pos_g = rand_edge_split(g, 0.9)
65+
test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges, bidirected=isbidir)
6966

7067
### DEFINE MODEL #########
7168
nin, nhidden = size(X,1), args.nhidden
@@ -82,24 +79,30 @@ function train(; kws...)
8279

8380
### LOSS FUNCTION ############
8481

85-
function loss(pos_g, neg_g = nothing)
82+
function loss(pos_g, neg_g = nothing; with_accuracy=false)
8683
h = model(X)
8784
if neg_g === nothing
8885
# We sample a negative graph at each training step
89-
neg_g = negative_sample(pos_g)
86+
neg_g = negative_sample(pos_g, bidirected=isbidir)
9087
end
9188
pos_score = pred(pos_g, h)
9289
neg_score = pred(neg_g, h)
9390
scores = [pos_score; neg_score]
9491
labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
95-
return logitbinarycrossentropy(scores, labels)
92+
l = logitbinarycrossentropy(scores, labels)
93+
if with_accuracy
94+
acc = 0.5 * mean(pos_score .>= 0) + 0.5 * mean(neg_score .< 0)
95+
return l, acc
96+
else
97+
return l
98+
end
9699
end
97100

98101
### LOGGING FUNCTION
99102
function report(epoch)
100-
train_loss = loss(train_pos_g)
101-
test_loss = loss(test_pos_g, test_neg_g)
102-
println("Epoch: $epoch Train: $(train_loss) Test: $(test_loss)")
103+
train_loss, train_acc = loss(train_pos_g, with_accuracy=true)
104+
test_loss, test_acc = loss(test_pos_g, test_neg_g, with_accuracy=true)
105+
println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))")
103106
end
104107

105108
### TRAINING

src/GNNGraphs/GNNGraphs.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,29 @@ export GNNGraph,
2222
graph_features
2323

2424
include("query.jl")
25-
export edge_index,
26-
adjacency_list,
27-
normalized_laplacian,
28-
scaled_laplacian,
25+
export adjacency_list,
26+
edge_index,
2927
graph_indicator,
28+
has_multi_edges,
3029
is_bidirected,
30+
normalized_laplacian,
31+
scaled_laplacian,
3132
# from Graphs
3233
adjacency_matrix,
3334
degree,
34-
outneighbors,
35-
inneighbors
35+
has_self_loops,
36+
inneighbors,
37+
outneighbors
3638

3739
include("transform.jl")
3840
export add_nodes,
3941
add_edges,
40-
add_self_loops,
41-
remove_self_loops,
42-
remove_multi_edges,
42+
add_self_loops,
4343
getgraph,
4444
negative_sample,
45+
rand_edge_split,
46+
remove_self_loops,
47+
remove_multi_edges,
4548
# from Flux
4649
batch,
4750
unbatch,
@@ -51,6 +54,9 @@ export add_nodes,
5154
include("generate.jl")
5255
export rand_graph
5356

57+
include("operators.jl")
58+
# Base.intersect
59+
5460
include("convert.jl")
5561
include("utils.jl")
5662

src/GNNGraphs/convert.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ function to_coo(coo::COO_T; dir=:out, num_nodes=nothing)
55
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
66
@assert isnothing(val) || length(val) == length(s)
77
@assert length(s) == length(t)
8-
@assert min(minimum(s), minimum(t)) >= 1
9-
@assert max(maximum(s), maximum(t)) <= num_nodes
10-
8+
if !isempty(s)
9+
@assert min(minimum(s), minimum(t)) >= 1
10+
@assert max(maximum(s), maximum(t)) <= num_nodes
11+
end
1112
num_edges = length(s)
1213
return coo, num_nodes, num_edges
1314
end

src/GNNGraphs/gnngraph.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ function GNNGraph(data;
150150
ndata, edata, gdata)
151151
end
152152

153+
function GNNGraph(n::T; graph_type=:coo, kws...) where {T<:Integer}
154+
s, t = T[], T[]
155+
return GNNGraph(s, t; graph_type, num_nodes=n, kws...)
156+
end
157+
153158
# COO convenience constructors
154159
GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...)
155160
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)

src/GNNGraphs/operators.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# 2 or more args graph operators
2+
function Base.intersect(g1::GNNGraph, g2::GNNGraph)
3+
@assert g1.num_nodes == g2.num_nodes
4+
@assert graph_type_symbol(g1) == graph_type_symbol(g2)
5+
graph_type = graph_type_symbol(g1)
6+
num_nodes = g1.num_nodes
7+
8+
idx1, _ = edge_encoding(edge_index(g1)..., num_nodes)
9+
idx2, _ = edge_encoding(edge_index(g2)..., num_nodes)
10+
idx = intersect(idx1, idx2)
11+
s, t = edge_decoding(idx, num_nodes)
12+
return GNNGraph(s, t; num_nodes, graph_type)
13+
end

src/GNNGraphs/query.jl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ end
2828

2929
Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i,j] != 0
3030

31+
graph_type_symbol(g::GNNGraph{<:COO_T}) = :coo
32+
graph_type_symbol(g::GNNGraph{<:SPARSE_T}) = :sparse
33+
graph_type_symbol(g::GNNGraph{<:ADJMAT_T}) = :dense
34+
3135
Graphs.nv(g::GNNGraph) = g.num_nodes
3236
Graphs.ne(g::GNNGraph) = g.num_edges
3337
Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes
@@ -243,10 +247,35 @@ function is_bidirected(g::GNNGraph)
243247
all((s1 .== s2) .& (t1 .== t2))
244248
end
245249

246-
@non_differentiable normalized_laplacian(x...)
247-
@non_differentiable normalized_adjacency(x...)
248-
@non_differentiable scaled_laplacian(x...)
249-
@non_differentiable adjacency_matrix(x...)
250+
"""
251+
has_self_loops(g::GNNGraph)
252+
253+
Return `true` if `g` has any self loops.
254+
"""
255+
function Graphs.has_self_loops(g::GNNGraph)
256+
s, t = edge_index(g)
257+
any(s .== t)
258+
end
259+
260+
"""
261+
has_multi_edges(g::GNNGraph)
262+
263+
Return `true` if `g` has any multiple edges.
264+
"""
265+
function has_multi_edges(g::GNNGraph)
266+
s, t = edge_index(g)
267+
idxs = edge_encoding(s, t, g.num_nodes)
268+
length(union(idxs)) < length(idxs)
269+
end
270+
271+
250272
@non_differentiable adjacency_list(x...)
273+
@non_differentiable adjacency_matrix(x...)
251274
@non_differentiable degree(x...)
252275
@non_differentiable graph_indicator(x...)
276+
@non_differentiable has_multi_edges(x...)
277+
@non_differentiable Graphs.has_self_loops(x...)
278+
@non_differentiable is_bidirected(x...)
279+
@non_differentiable normalized_adjacency(x...)
280+
@non_differentiable normalized_laplacian(x...)
281+
@non_differentiable scaled_laplacian(x...)

src/GNNGraphs/transform.jl

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,21 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
321321
end
322322

323323
"""
324-
negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
324+
negative_sample(g::GNNGraph;
325+
num_neg_edges = g.num_edges,
326+
bidirected = is_bidirected(g))
325327
326328
Return a graph containing random negative edges (i.e. non-edges) from graph `g` as edges.
329+
330+
Is `bidirected=true`, the output graph will be bidirected and there will be no
331+
leakage from the origin graph.
332+
333+
See also [`is_bidirected`](@ref).
327334
"""
328335
function negative_sample(g::GNNGraph;
329336
max_trials=3,
330-
num_neg_edges=g.num_edges)
337+
num_neg_edges=g.num_edges,
338+
bidirected = is_bidirected(g))
331339

332340
@assert g.num_graphs == 1
333341
# Consider self-loops as positive edges
@@ -344,8 +352,12 @@ function negative_sample(g::GNNGraph;
344352
device = Flux.cpu
345353
end
346354
idx_pos, maxid = edge_encoding(s, t, n)
347-
348-
pneg = 1 - g.num_edges / maxid # prob of selecting negative edge
355+
if bidirected
356+
num_neg_edges = num_neg_edges ÷ 2
357+
pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge
358+
else
359+
pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge
360+
end
349361
# pneg * sample_prob * maxid == num_neg_edges
350362
sample_prob = min(1, num_neg_edges / (pneg * maxid) * 1.1)
351363
idx_neg = Int[]
@@ -359,26 +371,44 @@ function negative_sample(g::GNNGraph;
359371
end
360372
end
361373
s_neg, t_neg = edge_decoding(idx_neg, n)
374+
if bidirected
375+
s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg]
376+
end
362377
return GNNGraph(s_neg, t_neg, num_nodes=n) |> device
363378
end
364379

365-
# each edge is represented by a number in
366-
# 1:N^2
367-
function edge_encoding(s, t, n)
368-
idx = (s .- 1) .* n .+ t
369-
maxid = n^2
370-
return idx, maxid
371-
end
380+
"""
381+
rand_edge_split(g::GNNGraph, frac) -> g1, g2
382+
383+
Randomly partition the edges in `g` to from two graphs, `g1`
384+
and `g2`. Both will have the same number of nodes as `g`.
385+
`g1` will contain a fraction `frac` of the original edges,
386+
while `g2` wil contain the rest.
387+
Useful for train/test splits in link prediction tasks.
388+
"""
389+
function rand_edge_split(g::GNNGraph, frac)
390+
# TODO add bidirected version
391+
s, t = edge_index(g)
392+
eids = randperm(g.num_edges)
393+
size1 = round(Int, g.num_edges * frac)
394+
395+
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
396+
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)
397+
398+
s, t = edge_index(g)
399+
eids = randperm(g.num_edges)
400+
size1 = round(Int, g.num_edges * frac)
401+
402+
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
403+
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)
372404

373-
# each edge is represented by a number in
374-
# 1:N^2
375-
function edge_decoding(idx, n)
376-
# g = remove_self_loops(g)
377-
s = (idx .- 1) n .+ 1
378-
t = (idx .- 1) .% n .+ 1
379-
return s, t
405+
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
406+
g2 = GNNGraph(s2, t2, num_nodes=g.num_nodes)
407+
408+
return g1, g2
380409
end
381410

411+
382412
# """
383413
# Transform vector of cartesian indexes into a tuple of vectors containing integers.
384414
# """

src/GNNGraphs/utils.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,59 @@ end
7070
ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz), 1)
7171
ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz)
7272
ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz)
73+
74+
75+
# each edge is represented by a number in
76+
# 1:N^2
77+
function edge_encoding(s, t, n; directed=true)
78+
if directed
79+
# directed edges and self-loops allowed
80+
idx = (s .- 1) .* n .+ t
81+
maxid = n^2
82+
else
83+
# Undirected edges and self-loops allowed
84+
maxid = n * (n + 1) ÷ 2
85+
86+
mask = s .> t
87+
snew = copy(s)
88+
tnew = copy(t)
89+
snew[mask] .= t[mask]
90+
tnew[mask] .= s[mask]
91+
s, t = snew, tnew
92+
93+
# idx = ∑_{i',i'<i} ∑_{j',j'>=i'}^n 1 + ∑_{j',i<=j'<=j} 1
94+
# = ∑_{i',i'<i} ∑_{j',j'>=i'}^n 1 + (j - i + 1)
95+
# = ∑_{i',i'<i} (n - i' + 1) + (j - i + 1)
96+
# = (i - 1)*(2*(n+1)-i)÷2 + (j - i + 1)
97+
idx = @. (s-1)*(2*(n+1)-s)÷2 + (t-s+1)
98+
end
99+
return idx, maxid
100+
end
101+
102+
# each edge is represented by a number in
103+
# 1:N^2
104+
function edge_decoding(idx, n; directed=true)
105+
if directed
106+
# g = remove_self_loops(g)
107+
s = (idx .- 1) n .+ 1
108+
t = (idx .- 1) .% n .+ 1
109+
else
110+
# We replace j=n in
111+
# idx = (i - 1)*(2*(n+1)-i)÷2 + (j - i + 1)
112+
# and obtain
113+
# idx = (i - 1)*(2*(n+1)-i)÷2 + (n - i + 1)
114+
115+
# OR We replace j=i and obtain??
116+
# idx = (i - 1)*(2*(n+1)-i)÷2 + 1
117+
118+
# inverting we have
119+
s = @. ceil(Int, -sqrt((n + 1/2)^2 - 2*idx) + n + 1/2)
120+
t = @. idx - (s-1)*(2*(n+1)-s)÷2 - 1 + s
121+
# t = (idx .- 1) .% n .+ 1
122+
end
123+
return s, t
124+
end
125+
126+
@non_differentiable edge_encoding(x...)
127+
@non_differentiable edge_decoding(x...)
128+

0 commit comments

Comments
 (0)