Skip to content

Commit fe26097

Browse files
negative_sample
1 parent df4f1a0 commit fe26097

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

examples/link_prediction_pubmed.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,14 @@ function train(; kws...)
5555
@show has_self_loops(g)
5656
@show has_multi_edges(g)
5757
@show mean(degree(g))
58+
isbidir = is_bidirected(g)
5859

5960
g = g |> device
6061
X = data.node_features |> device
6162

6263
#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
6364
train_pos_g, test_pos_g = rand_edge_split(g, 0.9)
64-
test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges)
65+
test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges, bidirected=isbidir)
6566

6667
### DEFINE MODEL #########
6768
nin, nhidden = size(X,1), args.nhidden
@@ -82,7 +83,7 @@ function train(; kws...)
8283
h = model(X)
8384
if neg_g === nothing
8485
# We sample a negative graph at each training step
85-
neg_g = negative_sample(pos_g)
86+
neg_g = negative_sample(pos_g, bidirected=isbidir)
8687
end
8788
pos_score = pred(pos_g, h)
8889
neg_score = pred(neg_g, h)

src/GNNGraphs/transform.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,21 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
321321
end
322322

323323
"""
324-
negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
324+
negative_sample(g::GNNGraph;
325+
num_neg_edges = g.num_edges,
326+
bidirected = is_bidirected(g))
325327
326328
Return a graph containing random negative edges (i.e. non-edges) from graph `g` as edges.
329+
330+
If `bidirected=true`, the output graph will be bidirected and there will be no
331+
leakage from the original graph.
332+
333+
See also [`is_bidirected`](@ref).
327334
"""
328335
function negative_sample(g::GNNGraph;
329336
max_trials=3,
330-
num_neg_edges=g.num_edges)
337+
num_neg_edges=g.num_edges,
338+
bidirected = is_bidirected(g))
331339

332340
@assert g.num_graphs == 1
333341
# Consider self-loops as positive edges
@@ -344,8 +352,12 @@ function negative_sample(g::GNNGraph;
344352
device = Flux.cpu
345353
end
346354
idx_pos, maxid = edge_encoding(s, t, n)
347-
348-
pneg = 1 - g.num_edges / maxid # prob of selecting negative edge
355+
if bidirected
356+
num_neg_edges = num_neg_edges ÷ 2
357+
pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge
358+
else
359+
pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge
360+
end
349361
# pneg * sample_prob * maxid == num_neg_edges
350362
sample_prob = min(1, num_neg_edges / (pneg * maxid) * 1.1)
351363
idx_neg = Int[]
@@ -359,6 +371,9 @@ function negative_sample(g::GNNGraph;
359371
end
360372
end
361373
s_neg, t_neg = edge_decoding(idx_neg, n)
374+
if bidirected
375+
s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg]
376+
end
362377
return GNNGraph(s_neg, t_neg, num_nodes=n) |> device
363378
end
364379

@@ -372,6 +387,7 @@ while `g2` will contain the rest.
372387
Useful for train/test splits in link prediction tasks.
373388
"""
374389
function rand_edge_split(g::GNNGraph, frac)
390+
# TODO add bidirected version
375391
s, t = edge_index(g)
376392
eids = randperm(g.num_edges)
377393
size1 = round(Int, g.num_edges * frac)

test/GNNGraphs/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,7 @@
2727
sdec, tdec = GNNGraphs.edge_decoding(idx, n, directed=false)
2828
@test sdec == snew
2929
@test tdec == tnew
30+
31+
g = rand_graph(10, 30, bidirected=true)
3032
end
3133
end

0 commit comments

Comments
 (0)