Skip to content

Commit 2b28697

Browse files
Merge pull request #71 from CarloLucibello/cl/negative
bidirected graph support in rand_edge_split
2 parents d5c7a80 + a3cd35b commit 2b28697

File tree

4 files changed

+67
-25
lines changed

4 files changed

+67
-25
lines changed

examples/link_prediction_pubmed.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ using Flux
66
using Flux: onecold, onehotbatch
77
using Flux.Losses: logitbinarycrossentropy
88
using GraphNeuralNetworks
9-
using MLDatasets: PubMed, Cora
9+
using MLDatasets: PubMed
1010
using Statistics, Random, LinearAlgebra
1111
using CUDA
12-
# using MLJBase: AreaUnderCurve
1312
CUDA.allowscalar(false)
1413

1514
# arguments for the `train` function
@@ -28,7 +27,7 @@ struct DotPredictor end
2827

2928
function (::DotPredictor)(g, x)
3029
z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
31-
# z = apply_edges(xi_dot_xj, g, xi=x, xj=x) # Same with buit-in methods
30+
# z = apply_edges(xi_dot_xj, g, xi=x, xj=x) # Same with built-in method
3231
return vec(z)
3332
end
3433

@@ -47,21 +46,25 @@ function train(; kws...)
4746
end
4847

4948
### LOAD DATA
50-
data = Cora.dataset()
51-
# data = PubMed.dataset()
49+
data = PubMed.dataset()
5250
g = GNNGraph(data.adjacency_list)
51+
52+
# Print some info
5353
@info g
5454
@show is_bidirected(g)
5555
@show has_self_loops(g)
5656
@show has_multi_edges(g)
5757
@show mean(degree(g))
5858
isbidir = is_bidirected(g)
5959

60+
# Move to device
6061
g = g |> device
6162
X = data.node_features |> device
6263

63-
#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
64-
train_pos_g, test_pos_g = rand_edge_split(g, 0.9)
64+
#### TRAIN/TEST splits
65+
# With a bidirected graph, we make sure that an edge and its reverse
66+
# are in the same split
67+
train_pos_g, test_pos_g = rand_edge_split(g, 0.9, bidirected=isbidir)
6568
test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges, bidirected=isbidir)
6669

6770
### DEFINE MODEL #########

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using SparseArrays
44
using Functors: @functor
55
using CUDA
66
import Graphs
7-
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
7+
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, has_self_loops
88
import Flux
99
using Flux: batch
1010
import NNlib

src/GNNGraphs/transform.jl

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -378,33 +378,41 @@ function negative_sample(g::GNNGraph;
378378
end
379379

380380
"""
381-
rand_edge_split(g::GNNGraph, frac) -> g1, g2
381+
rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g)) -> g1, g2
382382
383383
Randomly partition the edges in `g` to form two graphs, `g1`
384384
and `g2`. Both will have the same number of nodes as `g`.
385385
`g1` will contain a fraction `frac` of the original edges,
386386
while `g2` will contain the rest.
387-
Useful for train/test splits in link prediction tasks.
388-
"""
389-
function rand_edge_split(g::GNNGraph, frac)
390-
# TODO add bidirected version
391-
s, t = edge_index(g)
392-
eids = randperm(g.num_edges)
393-
size1 = round(Int, g.num_edges * frac)
394-
395-
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
396-
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)
397387
388+
If `bidirected = true`, it makes sure that an edge and its reverse go into the same split.
389+
This option is supported only for bidirected graphs with no self-loops
390+
and no multi-edges.
391+
392+
`rand_edge_split` is typically used to create train/test splits in link prediction tasks.
393+
"""
394+
function rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g))
398395
s, t = edge_index(g)
399-
eids = randperm(g.num_edges)
400-
size1 = round(Int, g.num_edges * frac)
396+
ne = bidirected ? g.num_edges ÷ 2 : g.num_edges
397+
eids = randperm(ne)
398+
size1 = round(Int, ne * frac)
401399

402-
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
400+
if !bidirected
401+
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
402+
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
403+
else
404+
@assert is_bidirected(g)
405+
@assert !has_self_loops(g)
406+
@assert !has_multi_edges(g)
407+
mask = s .< t
408+
s, t = s[mask], t[mask]
409+
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
410+
s1, t1 = [s1; t1], [t1; s1]
411+
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
412+
s2, t2 = [s2; t2], [t2; s2]
413+
end
403414
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)
404-
405-
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
406415
g2 = GNNGraph(s2, t2, num_nodes=g.num_nodes)
407-
408416
return g1, g2
409417
end
410418

test/GNNGraphs/transform.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,35 @@
140140
@test intersect(g, gneg).num_edges == 0
141141
end
142142
end
143+
144+
@testset "rand_edge_split" begin
145+
if GRAPH_T == :coo
146+
n, m = 100,300
147+
148+
g = rand_graph(n, m, bidirected=true, graph_type=GRAPH_T)
149+
# check bidirected=is_bidirected(g) default
150+
g1, g2 = rand_edge_split(g, 0.9)
151+
@test is_bidirected(g1)
152+
@test is_bidirected(g2)
153+
@test intersect(g1, g2).num_edges == 0
154+
@test g1.num_edges + g2.num_edges == g.num_edges
155+
@test g2.num_edges < 50
156+
157+
g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
158+
# check bidirected=is_bidirected(g) default
159+
g1, g2 = rand_edge_split(g, 0.9)
160+
@test !is_bidirected(g1)
161+
@test !is_bidirected(g2)
162+
@test intersect(g1, g2).num_edges == 0
163+
@test g1.num_edges + g2.num_edges == g.num_edges
164+
@test g2.num_edges < 50
165+
166+
g1, g2 = rand_edge_split(g, 0.9, bidirected=false)
167+
@test !is_bidirected(g1)
168+
@test !is_bidirected(g2)
169+
@test intersect(g1, g2).num_edges == 0
170+
@test g1.num_edges + g2.num_edges == g.num_edges
171+
@test g2.num_edges < 50
172+
end
173+
end
143174
end

0 commit comments

Comments
 (0)