Skip to content

Commit 13468e3

Browse files
implement negative sampling
1 parent 81fb4d1 commit 13468e3

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
2323
graph_indicator
2424

2525
include("transform.jl")
26-
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph
26+
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph,
27+
negative_sample
2728

2829
include("generate.jl")
2930
export rand_graph

src/GNNGraphs/transform.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,5 +324,26 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
324324
end
325325
end
326326

327+
328+
"""
329+
negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
330+
331+
Return a graph containing random negative edges (i.e. non-edges) from graph `g`.
332+
"""
333+
function negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
334+
adj = adjacency_matrix(g)
335+
adj_neg = 1 .- adj - I
336+
neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)
337+
neg_eids = randperm(length(neg_s))[1:num_neg_edges]
338+
neg_s, neg_t = neg_s[neg_eids], neg_t[neg_eids]
339+
return GNNGraph(neg_s, neg_t, num_nodes=g.num_nodes)
340+
end
341+
342+
# """
343+
# Transform vector of cartesian indexes into a tuple of vectors containing integers.
344+
# """
345+
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
346+
347+
@non_differentiable negative_sample(x...)
327348
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
328349
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule

0 commit comments

Comments
 (0)