@@ -321,13 +321,21 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
321
321
end
322
322
323
323
"""
324
- negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
324
+ negative_sample(g::GNNGraph;
325
+ num_neg_edges = g.num_edges,
326
+ bidirected = is_bidirected(g))
325
327
326
328
Return a graph containing random negative edges (i.e. non-edges) from graph `g` as edges.
329
+
330
+ Is `bidirected=true`, the output graph will be bidirected and there will be no
331
+ leakage from the origin graph.
332
+
333
+ See also [`is_bidirected`](@ref).
327
334
"""
328
335
function negative_sample (g:: GNNGraph ;
329
336
max_trials= 3 ,
330
- num_neg_edges= g. num_edges)
337
+ num_neg_edges= g. num_edges,
338
+ bidirected = is_bidirected (g))
331
339
332
340
@assert g. num_graphs == 1
333
341
# Consider self-loops as positive edges
@@ -344,8 +352,12 @@ function negative_sample(g::GNNGraph;
344
352
device = Flux. cpu
345
353
end
346
354
idx_pos, maxid = edge_encoding (s, t, n)
347
-
348
- pneg = 1 - g. num_edges / maxid # prob of selecting negative edge
355
+ if bidirected
356
+ num_neg_edges = num_neg_edges ÷ 2
357
+ pneg = 1 - g. num_edges / 2 maxid # prob of selecting negative edge
358
+ else
359
+ pneg = 1 - g. num_edges / 2 maxid # prob of selecting negative edge
360
+ end
349
361
# pneg * sample_prob * maxid == num_neg_edges
350
362
sample_prob = min (1 , num_neg_edges / (pneg * maxid) * 1.1 )
351
363
idx_neg = Int[]
@@ -359,6 +371,9 @@ function negative_sample(g::GNNGraph;
359
371
end
360
372
end
361
373
s_neg, t_neg = edge_decoding (idx_neg, n)
374
+ if bidirected
375
+ s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg]
376
+ end
362
377
return GNNGraph (s_neg, t_neg, num_nodes= n) |> device
363
378
end
364
379
@@ -372,6 +387,7 @@ while `g2` wil contain the rest.
372
387
Useful for train/test splits in link prediction tasks.
373
388
"""
374
389
function rand_edge_split (g:: GNNGraph , frac)
390
+ # TODO add bidirected version
375
391
s, t = edge_index (g)
376
392
eids = randperm (g. num_edges)
377
393
size1 = round (Int, g. num_edges * frac)
0 commit comments