diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index bcd65aabe..99a14dbfb 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -73,6 +73,7 @@ export Set2Set, TopKPool, topk_index, + WeigthAndSumPool, # mldatasets mldataset2gnngraph diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 9b208d0dc..b03fbf59f 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -144,6 +144,47 @@ end topk_index(y::Adjoint, k::Int) = topk_index(y', k) +""" + WeigthAndSumPool(in_feats) + +WeigthAndSum sum pooling layer. +Takes a graph and the node features as inputs, computes the weights for each node and perform a weighted sum. + +# Example + +```julia +n = 3 +chin = 5 + +ws = WeigthAndSumPool(chin) +g = GNNGraph(rand_graph(30, 50), ndata = rand(Float32, chin, 30)) + +u = ws(g, g.ndata.x) +``` +""" +struct WeigthAndSumPool + in_feats::Int + dense_layer::Dense +end + +@functor WeigthAndSumPool + +function WeigthAndSumPool(in_feats::Int) + dense_layer = Dense(in_feats, 1, sigmoid; bias = true) + WeigthAndSumPool(in_feats, dense_layer) +end + +function (ws::WeigthAndSumPool)(g::GNNGraph, x::AbstractArray) + atom_weighting = ws.dense_layer + return reduce_nodes(+, g, atom_weighting(x) .* x) +end + +function (ws::WeigthAndSumPool)(g::GNNGraph, x::CuArray) + atom_weighting = ws.dense_layer |> gpu + return reduce_nodes(+, g, atom_weighting(x) .* x) +end + +(ws::WeigthAndSumPool)(g::GNNGraph) = GNNGraph(g, gdata = ws(g, node_features(g))) @doc raw""" Set2Set(n_in, n_iters, n_layers = 1) diff --git a/test/layers/pool.jl b/test/layers/pool.jl index 24f5d66bf..15833a857 100644 --- a/test/layers/pool.jl +++ b/test/layers/pool.jl @@ -53,6 +53,22 @@ end y = p(X) @test size(y) == (in_channel, k) end + + @testset "WeigthAndSumPool" begin + n = 3 + chin = 5 + ng = 3 + + ws = WeigthAndSumPool(chin) + g = GNNGraph(rand_graph(n, 4), ndata = rand(Float32, chin, n), graph_type = GRAPH_T) + + test_layer(ws, g, rtol = 1e-5, outtype = :graph,outsize = (chin, 1)) + g_batch = Flux.batch([GNNGraph(rand_graph(n, 4), + ndata = rand(Float32, chin, n), + graph_type = GRAPH_T) + for i in 1:ng]) + test_layer(ws, g_batch, rtol = 1e-5,outtype = :graph, outsize = (chin, ng)) + end end @testset "topk_index" begin