1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -71,6 +71,7 @@ export
GlobalAttentionPool,
TopKPool,
topk_index,
WeightAndSumPool,

# mldatasets
mldataset2gnngraph
21 changes: 21 additions & 0 deletions src/layers/pool.jl
@@ -143,3 +143,24 @@ function topk_index(y::AbstractVector, k::Int)
end

topk_index(y::Adjoint, k::Int) = topk_index(y', k)

"""
WeigthAndSumPool(in_feats)

WeigthAndSum sum pooling layer. Compute the weights for each node and perform a weighted sum.
"""
struct WeightAndSumPool
    atom_weighting::Dense
end

Flux.@functor WeightAndSumPool

# Build the gating network once, at construction time, so that its parameters
# are trainable and move to the GPU together with the rest of the model.
# Constructing the Dense layer inside the forward pass would reinitialize it
# on every call; storing it in the struct also makes a dedicated CuArray
# method unnecessary.
WeightAndSumPool(in_feats::Int) = WeightAndSumPool(Dense(in_feats, 1, sigmoid))

function (ws::WeightAndSumPool)(g::GNNGraph, x::AbstractArray)
    return reduce_nodes(+, g, ws.atom_weighting(x) .* x)
end

(ws::WeightAndSumPool)(g::GNNGraph) = GNNGraph(g, gdata = ws(g, node_features(g)))
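
For reference, a minimal usage sketch of the layer as fixed above; the graph size and
feature dimension are illustrative, not taken from this PR:

    using Flux, GraphNeuralNetworks

    in_feats = 5
    ws = WeightAndSumPool(in_feats)

    # Random graph with 10 nodes, 30 edges, and 5 features per node.
    g = rand_graph(10, 30, ndata = rand(Float32, in_feats, 10))

    u = ws(g, node_features(g))   # graph-level representation of size (5, 1)

Because the Dense layer lives in the struct, Flux.params(ws) picks up its weights and
gpu(ws) moves them to the device, so the same forward pass works on CPU and GPU.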
16 changes: 16 additions & 0 deletions test/layers/pool.jl
@@ -61,4 +61,20 @@
@test topk_index(X, 4) == [1, 2, 3, 4]
@test topk_index(X', 4) == [1, 2, 3, 4]
end

@testset "WeightAndSumPool" begin
    n = 3      # nodes per graph
    chin = 5   # input feature channels
    ng = 3     # graphs in the batch

    ws = WeightAndSumPool(chin)
    g = GNNGraph(rand_graph(n, 4), ndata = rand(Float32, chin, n), graph_type = GRAPH_T)

    test_layer(ws, g, rtol = 1e-5, outtype = :graph, outsize = (chin, 1))
    g_batch = Flux.batch([GNNGraph(rand_graph(n, 4),
                                   ndata = rand(Float32, chin, n),
                                   graph_type = GRAPH_T)
                          for i in 1:ng])
    test_layer(ws, g_batch, rtol = 1e-5, outtype = :graph, outsize = (chin, ng))
end
end