@@ -378,19 +378,24 @@ function negative_sample(g::GNNGraph;
378
378
end
379
379
380
380
"""
381
- rand_edge_split(g::GNNGraph, frac) -> g1, g2
381
+ rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g) ) -> g1, g2
382
382
383
383
Randomly partition the edges in `g` to from two graphs, `g1`
384
384
and `g2`. Both will have the same number of nodes as `g`.
385
385
`g1` will contain a fraction `frac` of the original edges,
386
386
while `g2` wil contain the rest.
387
+
388
+ If `bidirected = true` makes sure that an edge and its reverse go into the same split.
389
+
387
390
Useful for train/test splits in link prediction tasks.
388
391
"""
389
- function rand_edge_split (g:: GNNGraph , frac)
390
- # TODO add bidirected version
392
+ function rand_edge_split (g:: GNNGraph , frac; bidirected= is_bidirected (g))
391
393
s, t = edge_index (g)
392
- eids = randperm (g. num_edges)
393
- size1 = round (Int, g. num_edges * frac)
394
+ idx, idmax = edge_encoding (s, t, g. num_nodes, directed= ! bidirected)
395
+ uidx = union (idx) # So that multi-edges (and reverse edges in the bidir case) go in the same split
396
+ nu = length (uidx)
397
+ eids = randperm (nu)
398
+ size1 = round (Int, nu * frac)
394
399
395
400
s1, t1 = s[eids[1 : size1]], t[eids[1 : size1]]
396
401
g1 = GNNGraph (s1, t1, num_nodes= g. num_nodes)
0 commit comments