Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ export
broadcast_nodes,
broadcast_edges,
softmax_edge_neighbors,
topk_nodes,
topk_edges,

# msgpass
apply_edges,
Expand Down
54 changes: 54 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,57 @@ function broadcast_edges(g::GNNGraph, x)
gi = graph_indicator(g, edges = true)
return gather(x, gi)
end

function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
index = sortperm(view(matrix, sortby, :); rev)
return matrix[:, index]
end

function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
if sortby === nothing
return sort(matrix, dims = 2; rev)[:, 1:k]
else
return _sort_col(matrix; rev, sortby)[:, 1:k]
end
end

function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing)
return map(x -> _sort_matrix(x, k; rev, sortby), matrices)
end

function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
sortby = nothing)
tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs,
number_graphs)
sorted_matrix = _sort_batch(eachslice(tensor_matrix, dims = 3), k; rev, sortby)
return reduce(hcat, sorted_matrix)
end

function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
sortby = nothing)
if number_graphs == 1
return _sort_matrix(matrix, k; rev, sortby)
else
return _topk_batch(matrix, number_graphs, k; rev, sortby)
end
end

"""
topk_nodes(g, feat, k; rev = true, sortby = nothing)

Graph-wise top-k on node features `feat` according to the `sortby` feature index.
"""
function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
matrix = getproperty(g.ndata, feat)
return _topk(matrix, g.num_graphs, k; rev, sortby)
end

"""
topk_edges(g, feat, k; rev = true, sortby = nothing)

Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
"""
function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
matrix = getproperty(g.edata, feat)
return _topk(matrix, g.num_graphs, k; rev, sortby)
end
35 changes: 35 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,39 @@
@test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2)
@test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2)
end

@testset "topk_nodes" begin
A = [1.0 5.0 9.0; 2.0 6.0 10.0; 3.0 7.0 11.0; 4.0 8.0 12.0]
B = [0.318907 0.189981 0.991791;
0.547022 0.977349 0.680538;
0.921823 0.35132 0.494715;
0.451793 0.00704976 0.0189275]
g1 = rand_graph(3, 6, ndata = (x = A,))
g2 = rand_graph(3, 6, ndata = B)
output1 = topk_nodes(g1, :x, 2)
output2 = topk_nodes(g2, :x, 1, sortby = 2)
@test output1 == [9.0 5.0;
10.0 6.0;
11.0 7.0;
12.0 8.0]
@test output2 == [0.189981;
0.977349;
0.35132;
0.00704976;;]
g = Flux.batch([g1, g2])
output3 = topk_nodes(g, :x, 2; sortby = 4)
@test output3 == [9.0 5.0 0.318907 0.991791;
10.0 6.0 0.547022 0.680538;
11.0 7.0 0.921823 0.494715;
12.0 8.0 0.451793 0.0189275]
end

@testset "topk_edges" begin
A = [0.157163 0.561874 0.886584 0.0475203 0.72576 0.815986;
0.852048 0.974619 0.0345627 0.874303 0.614322 0.113491]
g1 = rand_graph(5, 6, edata = (x = A,))
output1 = topk_edges(g1, :x, 2)
@test output1 == [0.886584 0.815986;
0.974619 0.874303]
end
end