1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -71,6 +71,7 @@ export
GlobalAttentionPool,
TopKPool,
topk_index,
WeigthAndSumPool,

# mldatasets
mldataset2gnngraph
32 changes: 32 additions & 0 deletions src/layers/pool.jl
@@ -143,3 +143,35 @@ function topk_index(y::AbstractVector, k::Int)
end

topk_index(y::Adjoint, k::Int) = topk_index(y', k)

"""
WeigthAndSumPool(in_feats)

Weight-and-sum pooling layer.
Takes a graph and the node features as inputs, computes the weights for each node and perform a weighted sum.
Suggested change
Takes a graph and the node features as inputs, computes the weights for each node and perform a weighted sum.
In the forward pass, takes a graph and the node features as inputs, computes the weights for each node and performs a weighted sum.

Should also describe how the weights are computed and describe the in_feats constructor argument.
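
For context, a minimal sketch of what the weighting amounts to, judging from the constructor below (the Dense(in_feats, 1, sigmoid) layer and the reduce_nodes call are taken from this PR; x is assumed to be an in_feats × num_nodes feature matrix):

atom_weighting = Dense(in_feats, 1, sigmoid)   # one scalar gate in (0, 1) per node
w = atom_weighting(x)                          # size (1, num_nodes)
r = reduce_nodes(+, g, w .* x)                 # per-graph weighted sum, size (in_feats, num_graphs)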

"""
struct WeigthAndSumPool
    in_feats::Int       # dimension of the input node features
    dense_layer::Dense  # maps each node's features to a scalar weight in (0, 1)
end

@functor WeigthAndSumPool

Flux.trainable(ws::WeigthAndSumPool) = (ws.dense_layer,)  # note the trailing comma: trainable must return a tuple, not the bare layer

function WeigthAndSumPool(in_feats::Int)
    dense_layer = Dense(in_feats, 1, sigmoid; bias = true)
    WeigthAndSumPool(in_feats, dense_layer)
end

function (ws::WeigthAndSumPool)(g::GNNGraph, x::AbstractArray)
    atom_weighting = ws.dense_layer
    return reduce_nodes(+, g, atom_weighting(x) .* x)
end

function (ws::WeigthAndSumPool)(g::GNNGraph, x::CuArray)
    # move the weighting layer to the GPU so it can act on the CuArray input
    atom_weighting = ws.dense_layer |> gpu
    return reduce_nodes(+, g, atom_weighting(x) .* x)
end

(ws::WeigthAndSumPool)(g::GNNGraph) = GNNGraph(g, gdata = ws(g, node_features(g)))
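
For illustration, a minimal usage sketch of the new layer (rand_graph and node_features are existing GraphNeuralNetworks.jl helpers; the graph and feature sizes here are arbitrary):

using Flux, GraphNeuralNetworks

ws = WeigthAndSumPool(5)                               # expects 5 input features per node
g = rand_graph(10, 30, ndata = rand(Float32, 5, 10))   # 10 nodes, 30 edges, random features
r = ws(g, node_features(g))                            # size (5, 1): weighted sum over all nodes

Calling ws(g) instead returns a new GNNGraph with the pooled vector stored in gdata, per the one-argument method above.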
16 changes: 16 additions & 0 deletions test/layers/pool.jl
@@ -53,6 +53,22 @@ end
    y = p(X)
    @test size(y) == (in_channel, k)
end

@testset "WeigthAndSumPool" begin
n = 3
chin = 5
ng = 3

ws = WeigthAndSumPool(chin)
g = GNNGraph(rand_graph(n, 4), ndata = rand(Float32, chin, n), graph_type = GRAPH_T)

test_layer(ws, g, rtol = 1e-5, outtype = :graph,outsize = (chin, 1))
g_batch = Flux.batch([GNNGraph(rand_graph(n, 4),
ndata = rand(Float32, chin, n),
graph_type = GRAPH_T)
for i in 1:ng])
test_layer(ws, g_batch, rtol = 1e-5,outtype = :graph, outsize = (chin, ng))
end
end

@testset "topk_index" begin