diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index 163d315b5..43cddbe8a 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -49,7 +49,9 @@ export TGCN,
        EvolveGCNO
 
 include("layers/pool.jl")
 
-export GlobalPool
+export GlobalPool,
+       GlobalAttentionPool,
+       TopKPool
 
 end #module
\ No newline at end of file
diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl
index 7fcc044f6..4d4b7273e 100644
--- a/GNNLux/src/layers/pool.jl
+++ b/GNNLux/src/layers/pool.jl
@@ -39,4 +39,95 @@ end
 
 (l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st
 
-(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
\ No newline at end of file
+function (l::GlobalPool)(g::GNNGraph, ps, st)
+    u, st = l(g, node_features(g), ps, st)
+    return GNNGraph(g, gdata = u), st
+end
+
+@doc raw"""
+    GlobalAttentionPool(fgate, ffeat=identity)
+
+Global soft attention layer from the [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) paper
+
+```math
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
+```
+
+where the coefficients ``\alpha_i`` are given by a [`GNNlib.softmax_nodes`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNlib.jl/stable/api/utils/#GNNlib.softmax_nodes)
+operation:
+
+```math
+\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
+                {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
+```
+
+# Arguments
+
+- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
+  It is typically expressed by a neural network.
+
+- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
+  It is typically expressed by a neural network.
+
+# Examples
+
+```julia
+using Graphs, LuxCore, Lux, GNNLux, Random
+
+rng = Random.default_rng()
+chin = 6
+chout = 5
+
+fgate = Dense(chin, 1)
+ffeat = Dense(chin, chout)
+pool = GlobalAttentionPool(fgate, ffeat)
+
+g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),
+                    ndata = rand(Float32, chin, 10))
+           for i in 1:3])
+
+ps = (fgate = LuxCore.initialparameters(rng, fgate),
+      ffeat = LuxCore.initialparameters(rng, ffeat))
+st = (fgate = LuxCore.initialstates(rng, fgate),
+      ffeat = LuxCore.initialstates(rng, ffeat))
+
+u, st = pool(g, g.ndata.x, ps, st)
+
+@assert size(u) == (chout, g.num_graphs)
+```
+"""
+@concrete struct GlobalAttentionPool <: GNNContainerLayer{(:fgate, :ffeat)}
+    fgate
+    ffeat
+end
+
+GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
+
+function (l::GlobalAttentionPool)(g, x, ps, st)
+    fgate = StatefulLuxLayer{true}(l.fgate, ps.fgate, _getstate(st, :fgate))
+    ffeat = StatefulLuxLayer{true}(l.ffeat, ps.ffeat, _getstate(st, :ffeat))
+    m = (; fgate, ffeat)
+    return GNNlib.global_attention_pool(m, g, x), st
+end
+
+function (l::GlobalAttentionPool)(g::GNNGraph, ps, st)
+    u, st = l(g, node_features(g), ps, st)
+    return GNNGraph(g, gdata = u), st
+end
+
+"""
+    TopKPool(adj, k, in_channel)
+
+Top-k pooling layer.
+
+# Arguments
+
+- `adj`: Adjacency matrix of a graph.
+- `k`: Top-k nodes are selected to pool together.
+- `in_channel`: The dimension of the input channel.
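+
+# Examples
+
+A minimal usage sketch; the graph size, `k`, and channel dimension are
+illustrative. `TopKPool` carries its own parameters, so empty `ps`/`st`
+named tuples are passed at call time:
+
+```julia
+using GNNLux
+
+N, k, in_channel = 10, 4, 7
+adj = rand(Bool, N, N)               # adjacency matrix of a random 10-node graph
+pool = TopKPool(adj, k, in_channel)  # learnable scoring vector of length in_channel
+
+x = rand(Float32, in_channel, N)     # node feature matrix
+y, st = pool(x, (;), (;))            # pooled features of the top-k nodes, plus state
+
+@assert size(y) == (in_channel, k)
+```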
+""" +struct TopKPool{T, S} + A::AbstractMatrix{T} + k::Int + p::AbstractVector{S} + Ã::AbstractMatrix{T} +end + +function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform) + TopKPool(adj, k, init(in_channel), similar(adj, k, k)) +end + +(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x), st diff --git a/GNNLux/test/layers/pool.jl b/GNNLux/test/layers/pool.jl index f1f7faeae..9a97812ea 100644 --- a/GNNLux/test/layers/pool.jl +++ b/GNNLux/test/layers/pool.jl @@ -1,15 +1,46 @@ @testitem "Pooling" setup=[TestModuleLux] begin using .TestModuleLux - @testset "GlobalPool" begin + @testset "Pooling" begin rng = StableRNG(1234) - g = rand_graph(rng, 10, 40) - in_dims = 3 - x = randn(rng, Float32, in_dims, 10) - - @testset "GCNConv" begin + @testset "GlobalPool" begin + g = rand_graph(rng, 10, 40) + in_dims = 3 + x = randn(rng, Float32, in_dims, 10) l = GlobalPool(mean) test_lux_layer(rng, l, g, x, sizey=(in_dims,1)) end + @testset "GlobalAttentionPool" begin + n = 10 + chin = 6 + chout = 5 + ng = 3 + g = batch([GNNGraph(rand_graph(rng, 10, 40), + ndata = rand(Float32, chin, n)) for i in 1:ng]) + + fgate = Dense(chin, 1) + ffeat = Dense(chin, chout) + l = GlobalAttentionPool(fgate, ffeat) + + test_lux_layer(rng, l, g, g.ndata.x, sizey=(chout,ng), container=true) + end + + @testset "TopKPool" begin + N = 10 + k, in_channel = 4, 7 + X = rand(in_channel, N) + ps = (;) + st = (;) + for T in [Bool, Float64] + adj = rand(T, N, N) + p = GNNLux.TopKPool(adj, k, in_channel) + @test eltype(p.p) === Float32 + @test size(p.p) == (in_channel,) + @test eltype(p.Ã) === T + @test size(p.Ã) == (k, k) + y, st = p(X, ps, st) + @test size(y) == (in_channel, k) + end + end end end diff --git a/GNNlib/src/layers/pool.jl b/GNNlib/src/layers/pool.jl index 40f983689..83d9c066f 100644 --- a/GNNlib/src/layers/pool.jl +++ b/GNNlib/src/layers/pool.jl @@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k) function set2set_pool(l, g::GNNGraph, x::AbstractMatrix) n_in = size(x, 1) qstar = zeros_like(x, (2*n_in, g.num_graphs)) - h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) - c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) + h = zeros_like(l.Wh, size(l.Wh, 2)) + c = zeros_like(l.Wh, size(l.Wh, 2)) state = (h, c) for t in 1:l.num_iters q, state = l.lstm(qstar, state) # [n_in, n_graphs] diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index a7d9ceaef..abcb4c115 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -61,10 +61,10 @@ operation: # Arguments - `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``. - It is tipically expressed by a neural network. + It is typically expressed by a neural network. - `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``. - It is tipically expressed by a neural network. + It is typically expressed by a neural network. # Examples @@ -156,7 +156,8 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) end function (l::Set2Set)(g, x) - return GNNlib.set2set_pool(l, g, x) + m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh) + return GNNlib.set2set_pool(m, g, x) end (l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))