Skip to content

Commit a3cd35b

Browse files
rand edge split
1 parent 446aba5 commit a3cd35b

File tree

4 files changed

+54
-22
lines changed

4 files changed

+54
-22
lines changed

examples/link_prediction_pubmed.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ using Flux
66
using Flux: onecold, onehotbatch
77
using Flux.Losses: logitbinarycrossentropy
88
using GraphNeuralNetworks
9-
using MLDatasets: PubMed, Cora
9+
using MLDatasets: PubMed
1010
using Statistics, Random, LinearAlgebra
1111
using CUDA
12-
# using MLJBase: AreaUnderCurve
1312
CUDA.allowscalar(false)
1413

1514
# arguments for the `train` function
@@ -47,8 +46,7 @@ function train(; kws...)
4746
end
4847

4948
### LOAD DATA
50-
data = Cora.dataset()
51-
# data = PubMed.dataset()
49+
data = PubMed.dataset()
5250
g = GNNGraph(data.adjacency_list)
5351

5452
# Print some info

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using SparseArrays
44
using Functors: @functor
55
using CUDA
66
import Graphs
7-
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
7+
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, has_self_loops
88
import Flux
99
using Flux: batch
1010
import NNlib

src/GNNGraphs/transform.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -386,30 +386,33 @@ and `g2`. Both will have the same number of nodes as `g`.
386386
while `g2` will contain the rest.
387387
388388
If `bidirected = true`, an edge and its reverse are guaranteed to go into the same split.
389+
This option is supported only for bidirected graphs with no self-loops
390+
or multi-edges.
389391
390-
Useful for train/test splits in link prediction tasks.
392+
`rand_edge_split` is typically used to create train/test splits in link prediction tasks.
391393
"""
392394
function rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g))
393395
s, t = edge_index(g)
394-
idx, idmax = edge_encoding(s, t, g.num_nodes, directed=!bidirected)
395-
uidx = union(idx) # So that multi-edges (and reverse edges in the bidir case) go in the same split
396-
nu = length(uidx)
397-
eids = randperm(nu)
398-
size1 = round(Int, nu * frac)
396+
ne = bidirected ? g.num_edges ÷ 2 : g.num_edges
397+
eids = randperm(ne)
398+
size1 = round(Int, ne * frac)
399399

400-
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
401-
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)
402-
403-
s, t = edge_index(g)
404-
eids = randperm(g.num_edges)
405-
size1 = round(Int, g.num_edges * frac)
406-
407-
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
400+
if !bidirected
401+
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
402+
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
403+
else
404+
@assert is_bidirected(g)
405+
@assert !has_self_loops(g)
406+
@assert !has_multi_edges(g)
407+
mask = s .< t
408+
s, t = s[mask], t[mask]
409+
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
410+
s1, t1 = [s1; t1], [t1; s1]
411+
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
412+
s2, t2 = [s2; t2], [t2; s2]
413+
end
408414
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)
409-
410-
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
411415
g2 = GNNGraph(s2, t2, num_nodes=g.num_nodes)
412-
413416
return g1, g2
414417
end
415418

test/GNNGraphs/transform.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,35 @@
140140
@test intersect(g, gneg).num_edges == 0
141141
end
142142
end
143+
144+
@testset "rand_edge_split" begin
145+
if GRAPH_T == :coo
146+
n, m = 100,300
147+
148+
g = rand_graph(n, m, bidirected=true, graph_type=GRAPH_T)
149+
# check bidirected=is_bidirected(g) default
150+
g1, g2 = rand_edge_split(g, 0.9)
151+
@test is_bidirected(g1)
152+
@test is_bidirected(g2)
153+
@test intersect(g1, g2).num_edges == 0
154+
@test g1.num_edges + g2.num_edges == g.num_edges
155+
@test g2.num_edges < 50
156+
157+
g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
158+
# check bidirected=is_bidirected(g) default
159+
g1, g2 = rand_edge_split(g, 0.9)
160+
@test !is_bidirected(g1)
161+
@test !is_bidirected(g2)
162+
@test intersect(g1, g2).num_edges == 0
163+
@test g1.num_edges + g2.num_edges == g.num_edges
164+
@test g2.num_edges < 50
165+
166+
g1, g2 = rand_edge_split(g, 0.9, bidirected=false)
167+
@test !is_bidirected(g1)
168+
@test !is_bidirected(g2)
169+
@test intersect(g1, g2).num_edges == 0
170+
@test g1.num_edges + g2.num_edges == g.num_edges
171+
@test g2.num_edges < 50
172+
end
173+
end
143174
end

0 commit comments

Comments
 (0)