Skip to content

Commit 446aba5

Browse files
wip on rand_split_edge
1 parent f96f58e commit 446aba5

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/GNNGraphs/transform.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,19 +378,24 @@ function negative_sample(g::GNNGraph;
378378
end
379379

380380
"""
381-
rand_edge_split(g::GNNGraph, frac) -> g1, g2
381+
rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g)) -> g1, g2
382382
383383
Randomly partition the edges in `g` to from two graphs, `g1`
384384
and `g2`. Both will have the same number of nodes as `g`.
385385
`g1` will contain a fraction `frac` of the original edges,
386386
while `g2` wil contain the rest.
387+
388+
If `bidirected = true` makes sure that an edge and its reverse go into the same split.
389+
387390
Useful for train/test splits in link prediction tasks.
388391
"""
389-
function rand_edge_split(g::GNNGraph, frac)
390-
# TODO add bidirected version
392+
function rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g))
391393
s, t = edge_index(g)
392-
eids = randperm(g.num_edges)
393-
size1 = round(Int, g.num_edges * frac)
394+
idx, idmax = edge_encoding(s, t, g.num_nodes, directed=!bidirected)
395+
uidx = union(idx) # So that multi-edges (and reverse edges in the bidir case) go in the same split
396+
nu = length(uidx)
397+
eids = randperm(nu)
398+
size1 = round(Int, nu * frac)
394399

395400
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
396401
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)

0 commit comments

Comments
 (0)