From a1733066a2254c335540939388d756f79b4023f2 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 14:21:30 +0100 Subject: [PATCH 01/13] Add `GlobalAttentionPool` --- GNNLux/src/GNNLux.jl | 3 +- GNNLux/src/layers/pool.jl | 69 +++++++++++++++++++++++++- GNNLux/test/layers/pool.jl | 12 +++-- GraphNeuralNetworks/src/layers/pool.jl | 4 +- 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 163d315b5..2c25811ef 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -49,7 +49,8 @@ export TGCN, EvolveGCNO include("layers/pool.jl") -export GlobalPool +export GlobalPool, + GlobalAttentionPool end #module \ No newline at end of file diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index 7fcc044f6..9f98f449c 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -39,4 +39,71 @@ end (l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st -(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) \ No newline at end of file +(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) + +@doc raw""" + GlobalAttentionPool(fgate, ffeat=identity) + +Global soft attention layer from the [Gated Graph Sequence Neural +Networks](https://arxiv.org/abs/1511.05493) paper + +```math +\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i) +``` + +where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref) +operation: + +```math +\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}} + {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}. +``` + +# Arguments + +- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``. + It is typically expressed by a neural network. + +- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``. + It is typically expressed by a neural network. + +# Examples + +```julia +using Graphs, LuxCore, Lux, GNNLux, Random + +rng = Random.default_rng() +chin = 6 +chout = 5 + +fgate = Dense(chin, 1) +ffeat = Dense(chin, chout) +pool = GlobalAttentionPool(fgate, ffeat) + +g = batch([GNNGraph(Graphs.random_regular_graph(10, 4), + ndata=rand(Float32, chin, 10)) + for i=1:3]) + +ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat)) +st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat)) + +u, st = pool(g, g.ndata.x, ps, st) + +@assert size(u) == (chout, g.num_graphs) +``` +""" +struct GlobalAttentionPool{G, F} + fgate::G + ffeat::F +end + +GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity) + +function (l::GlobalAttentionPool)(g, x, ps, st) + fgate = StatefulLuxLayer{true}(l.fgate, ps.fgate, _getstate(st, :fgate)) + ffeat = StatefulLuxLayer{true}(l.ffeat, ps.ffeat, _getstate(st, :ffeat)) + m = (; fgate, ffeat) + return GNNlib.global_attention_pool(m, g, x), st +end + +(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) \ No newline at end of file diff --git a/GNNLux/test/layers/pool.jl b/GNNLux/test/layers/pool.jl index f1f7faeae..69dee0926 100644 --- a/GNNLux/test/layers/pool.jl +++ b/GNNLux/test/layers/pool.jl @@ -1,15 +1,21 @@ @testitem "Pooling" setup=[TestModuleLux] begin using .TestModuleLux - @testset "GlobalPool" begin + @testset "Pooling" begin rng = StableRNG(1234) g = rand_graph(rng, 10, 40) in_dims = 3 x = randn(rng, Float32, in_dims, 10) - @testset "GCNConv" begin - l = GlobalPool(mean) + @testset "GlobalPool" begin + l = GNNLux.GlobalPool(mean) test_lux_layer(rng, l, g, x, sizey=(in_dims,1)) end + @testset "GlobalAttentionPool" begin + fgate = Dense(in_dims, 1) + ffeat = Dense(in_dims, in_dims) + l = GNNLux.GlobalAttentionPool(fgate, ffeat) + test_lux_layer(rng, l, g, x, sizey=(in_dims,1), container=true) + end end end diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index a7d9ceaef..b08803438 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -61,10 +61,10 @@ operation: # Arguments - `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``. - It is tipically expressed by a neural network. + It is typically expressed by a neural network. - `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``. - It is tipically expressed by a neural network. + It is typically expressed by a neural network. # Examples From 4adef57b83440281a1326ca8c169a43e5f2e7c5a Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 14:22:36 +0100 Subject: [PATCH 02/13] Fix --- GNNLux/test/layers/pool.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNLux/test/layers/pool.jl b/GNNLux/test/layers/pool.jl index 69dee0926..48f9ee65c 100644 --- a/GNNLux/test/layers/pool.jl +++ b/GNNLux/test/layers/pool.jl @@ -8,13 +8,13 @@ x = randn(rng, Float32, in_dims, 10) @testset "GlobalPool" begin - l = GNNLux.GlobalPool(mean) + l = GlobalPool(mean) test_lux_layer(rng, l, g, x, sizey=(in_dims,1)) end @testset "GlobalAttentionPool" begin fgate = Dense(in_dims, 1) ffeat = Dense(in_dims, in_dims) - l = GNNLux.GlobalAttentionPool(fgate, ffeat) + l = GlobalAttentionPool(fgate, ffeat) test_lux_layer(rng, l, g, x, sizey=(in_dims,1), container=true) end end From e76d150c59d1b4d76c2e06d19df7a62bc98fa7bd Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 14:49:10 +0100 Subject: [PATCH 03/13] Fixes --- GNNLux/src/layers/pool.jl | 8 ++++---- GNNLux/test/layers/pool.jl | 21 ++++++++++++++------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index 9f98f449c..22bd6e6e2 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -51,7 +51,7 @@ Networks](https://arxiv.org/abs/1511.05493) paper \mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i) ``` -where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref) +where the coefficients ``\alpha_i`` are given by a [`GNNLib.softmax_nodes`](@ref) operation: ```math @@ -92,9 +92,9 @@ u, st = pool(g, g.ndata.x, ps, st) @assert size(u) == (chout, g.num_graphs) ``` """ -struct GlobalAttentionPool{G, F} - fgate::G - ffeat::F +struct GlobalAttentionPool <: GNNContainerLayer{(:fgate, :ffeat)} + fgate + ffeat end GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity) diff --git a/GNNLux/test/layers/pool.jl b/GNNLux/test/layers/pool.jl index 48f9ee65c..0a6aff3d4 100644 --- a/GNNLux/test/layers/pool.jl +++ b/GNNLux/test/layers/pool.jl @@ -3,19 +3,26 @@ @testset "Pooling" begin rng = StableRNG(1234) - g = rand_graph(rng, 10, 40) - in_dims = 3 - x = randn(rng, Float32, in_dims, 10) - @testset "GlobalPool" begin + g = rand_graph(rng, 10, 40) + in_dims = 3 + x = randn(rng, Float32, in_dims, 10) l = GlobalPool(mean) test_lux_layer(rng, l, g, x, sizey=(in_dims,1)) end @testset "GlobalAttentionPool" begin - fgate = Dense(in_dims, 1) - ffeat = Dense(in_dims, in_dims) + n = 10 + chin = 6 + chout = 5 + ng = 3 + g = batch([GNNGraph(rand_graph(rng, 10, 40), + ndata = rand(Float32, chin, n)) for i in 1:ng]) + + fgate = Dense(chin, 1) + ffeat = Dense(chin, chout) l = GlobalAttentionPool(fgate, ffeat) - test_lux_layer(rng, l, g, x, sizey=(in_dims,1), container=true) + + test_lux_layer(rng, l, g, g.ndata.x, sizey=(chout,ng), container=true) end end end From bf947588a6bcc62001e7d74a06f2a344d8b0e44a Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 14:59:13 +0100 Subject: [PATCH 04/13] Fix docs --- GNNLux/src/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index 22bd6e6e2..578e43f04 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -51,7 +51,7 @@ Networks](https://arxiv.org/abs/1511.05493) paper \mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i) ``` -where the coefficients ``\alpha_i`` are given by a [`GNNLib.softmax_nodes`](@ref) +where the coefficients ``\alpha_i`` are given by a [`GNNLib.softmax_nodes`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNlib.jl/stable/api/utils/#GNNlib.softmax_nodes) operation: ```math From 321dee785afa040f78d0e9e372aabec5457d4508 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 15:48:48 +0100 Subject: [PATCH 05/13] Add `TopK` pooling --- GNNLux/src/layers/pool.jl | 64 +++++++++++++++++++++++++++++++++++++- GNNLux/test/layers/pool.jl | 18 +++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index 578e43f04..908834ba3 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -106,4 +106,66 @@ function (l::GlobalAttentionPool)(g, x, ps, st) return GNNlib.global_attention_pool(m, g, x), st end -(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) \ No newline at end of file +(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) + +""" + TopKPool(adj, k, in_channel) + +Top-k pooling layer. + +# Arguments + +- `adj`: Adjacency matrix of a graph. +- `k`: Top-k nodes are selected to pool together. +- `in_channel`: The dimension of input channel. +""" +struct TopKPool{T, S} + A::AbstractMatrix{T} + k::Int + p::AbstractVector{S} + Ã::AbstractMatrix{T} +end + +function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform) + TopKPool(adj, k, init(in_channel), similar(adj, k, k)) +end + +(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x) + + +@doc raw""" + Set2Set(n_in, n_iters, n_layers = 1) + +Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391). + +For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times: +```math +\mathbf{q} = \mathrm{LSTM}(\mathbf{q}_{t-1}^*) +\alpha_{i} = \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)} +\mathbf{r} = \sum_{i=1}^N \alpha_{i} \mathbf{x}_i +\mathbf{q}^*_t = [\mathbf{q}; \mathbf{r}] +``` +where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers, input size `2*n_in` and output size `n_in`. + +Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`. +``` +""" +struct Set2Set{L} <: GNNContainerLayer{(:lstm,)} + lstm::L + num_iters::Int +end + +function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) + @assert n_layers == 1 "multiple layers not implemented yet" #TODO + n_out = 2 * n_in + lstm = Lux.LSTMCell(n_out => n_in) + return Set2Set(lstm, n_iters) +end + +function (l::Set2Set)(g, x, ps, st) + lstm = StatefulLuxLayer{true}(l.lstm, ps.lstm, _getstate(st, :lstm)) + m = (; lstm, Wh = ps.lstm.weight_hh) + return GNNlib.set2set_pool(m, g, x) +end + +(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) diff --git a/GNNLux/test/layers/pool.jl b/GNNLux/test/layers/pool.jl index 0a6aff3d4..b57c8906e 100644 --- a/GNNLux/test/layers/pool.jl +++ b/GNNLux/test/layers/pool.jl @@ -24,5 +24,23 @@ test_lux_layer(rng, l, g, g.ndata.x, sizey=(chout,ng), container=true) end + + @testset "TopKPool" begin + N = 10 + k, in_channel = 4, 7 + X = rand(in_channel, N) + ps = (;) + st = (;) + for T in [Bool, Float64] + adj = rand(T, N, N) + p = GNNLux.TopKPool(adj, k, in_channel) + @test eltype(p.p) === Float32 + @test size(p.p) == (in_channel,) + @test eltype(p.Ã) === T + @test size(p.Ã) == (k, k) + y = p(X, ps, st) + @test size(y) == (in_channel, k) + end + end end end From 4283e36587d5f09bc059ba7ed909146a729f7b4d Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 15:49:16 +0100 Subject: [PATCH 06/13] Add `Set2Set` pooling layer --- GNNlib/src/layers/pool.jl | 4 ++-- GraphNeuralNetworks/src/layers/pool.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/GNNlib/src/layers/pool.jl b/GNNlib/src/layers/pool.jl index 40f983689..83d9c066f 100644 --- a/GNNlib/src/layers/pool.jl +++ b/GNNlib/src/layers/pool.jl @@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k) function set2set_pool(l, g::GNNGraph, x::AbstractMatrix) n_in = size(x, 1) qstar = zeros_like(x, (2*n_in, g.num_graphs)) - h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) - c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) + h = zeros_like(l.Wh, size(l.Wh, 2)) + c = zeros_like(l.Wh, size(l.Wh, 2)) state = (h, c) for t in 1:l.num_iters q, state = l.lstm(qstar, state) # [n_in, n_graphs] diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index b08803438..03ed25cd7 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -156,7 +156,8 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) end function (l::Set2Set)(g, x) - return GNNlib.set2set_pool(l, g, x) + m = (; lstm, Wh = lstm.Wh) + return GNNlib.set2set_pool(m, g, x) end (l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) From c8418990ac006536a81629bdef5e5cd1930e3a36 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 16:30:53 +0100 Subject: [PATCH 07/13] Fix --- GraphNeuralNetworks/src/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index 03ed25cd7..af2918d1a 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -156,7 +156,7 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) end function (l::Set2Set)(g, x) - m = (; lstm, Wh = lstm.Wh) + m = (; l.lstm, l.num_iters, Wh = lstm.Wh) return GNNlib.set2set_pool(m, g, x) end From bbea5a682f404506ac5fef1843969bf7c1db87bb Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 17:21:43 +0100 Subject: [PATCH 08/13] Fix --- GraphNeuralNetworks/src/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index af2918d1a..abcb4c115 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -156,7 +156,7 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) end function (l::Set2Set)(g, x) - m = (; l.lstm, l.num_iters, Wh = lstm.Wh) + m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh) return GNNlib.set2set_pool(m, g, x) end From c4986f653c740950eb6ab309488d99ab84c2f770 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 18:20:52 +0100 Subject: [PATCH 09/13] Add export tokpool --- GNNLux/src/GNNLux.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 2c25811ef..43cddbe8a 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -50,7 +50,8 @@ export TGCN, include("layers/pool.jl") export GlobalPool, - GlobalAttentionPool + GlobalAttentionPool, + TopKPool end #module \ No newline at end of file From 6df678b82beb5b697877d53530684ac66753716b Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Fri, 3 Jan 2025 18:21:51 +0100 Subject: [PATCH 10/13] Remove Set2Set not working --- GNNLux/src/layers/pool.jl | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index 908834ba3..c154aa2d7 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -131,41 +131,3 @@ function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_un end (t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x) - - -@doc raw""" - Set2Set(n_in, n_iters, n_layers = 1) - -Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391). - -For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times: -```math -\mathbf{q} = \mathrm{LSTM}(\mathbf{q}_{t-1}^*) -\alpha_{i} = \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)} -\mathbf{r} = \sum_{i=1}^N \alpha_{i} \mathbf{x}_i -\mathbf{q}^*_t = [\mathbf{q}; \mathbf{r}] -``` -where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers, input size `2*n_in` and output size `n_in`. - -Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`. -``` -""" -struct Set2Set{L} <: GNNContainerLayer{(:lstm,)} - lstm::L - num_iters::Int -end - -function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) - @assert n_layers == 1 "multiple layers not implemented yet" #TODO - n_out = 2 * n_in - lstm = Lux.LSTMCell(n_out => n_in) - return Set2Set(lstm, n_iters) -end - -function (l::Set2Set)(g, x, ps, st) - lstm = StatefulLuxLayer{true}(l.lstm, ps.lstm, _getstate(st, :lstm)) - m = (; lstm, Wh = ps.lstm.weight_hh) - return GNNlib.set2set_pool(m, g, x) -end - -(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) From e80872cccdd2c11c0b623a04df5c6138013a595f Mon Sep 17 00:00:00 2001 From: Aurora Rossi <65721467+aurorarossi@users.noreply.github.com> Date: Mon, 6 Jan 2025 22:51:09 +0100 Subject: [PATCH 11/13] Add st Co-authored-by: Carlo Lucibello --- GNNLux/src/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index c154aa2d7..7742e495f 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -130,4 +130,4 @@ function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_un TopKPool(adj, k, init(in_channel), similar(adj, k, k)) end -(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x) +(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x), st From c04853245944c2302e569b70946682c5a9abb84d Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Mon, 6 Jan 2025 22:52:39 +0100 Subject: [PATCH 12/13] Add concrete --- GNNLux/src/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index c154aa2d7..2669ce2c4 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -92,7 +92,7 @@ u, st = pool(g, g.ndata.x, ps, st) @assert size(u) == (chout, g.num_graphs) ``` """ -struct GlobalAttentionPool <: GNNContainerLayer{(:fgate, :ffeat)} +@concrete struct GlobalAttentionPool <: GNNContainerLayer{(:fgate, :ffeat)} fgate ffeat end From e5750669c327be10f24d87c4cd162285e107a96e Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Tue, 7 Jan 2025 11:22:03 +0100 Subject: [PATCH 13/13] Fix `TopK` pool test --- GNNLux/test/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/pool.jl b/GNNLux/test/layers/pool.jl index b57c8906e..9a97812ea 100644 --- a/GNNLux/test/layers/pool.jl +++ b/GNNLux/test/layers/pool.jl @@ -38,7 +38,7 @@ @test size(p.p) == (in_channel,) @test eltype(p.Ã) === T @test size(p.Ã) == (k, k) - y = p(X, ps, st) + y, st = p(X, ps, st) @test size(y) == (in_channel, k) end end