
Commit 2dd14fd

[GNNLux] Add pooling layers (#576)
1 parent 24eae1f commit 2dd14fd

File tree: 5 files changed, +138 -13 lines


GNNLux/src/GNNLux.jl

Lines changed: 3 additions & 1 deletion
```diff
@@ -49,7 +49,9 @@ export TGCN,
        EvolveGCNO

 include("layers/pool.jl")
-export GlobalPool
+export GlobalPool,
+       GlobalAttentionPool,
+       TopKPool

 end #module
```
GNNLux/src/layers/pool.jl

Lines changed: 92 additions & 1 deletion
````diff
@@ -39,4 +39,95 @@ end

 (l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

-(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
+(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
+
+@doc raw"""
+    GlobalAttentionPool(fgate, ffeat=identity)
+
+Global soft attention layer from the [Gated Graph Sequence Neural
+Networks](https://arxiv.org/abs/1511.05493) paper
+
+```math
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
+```
+
+where the coefficients ``\alpha_i`` are given by a [`GNNlib.softmax_nodes`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNlib.jl/stable/api/utils/#GNNlib.softmax_nodes)
+operation:
+
+```math
+\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
+                {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
+```
+
+# Arguments
+
+- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
+  It is typically expressed by a neural network.
+
+- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
+  It is typically expressed by a neural network.
+
+# Examples
+
+```julia
+using Graphs, LuxCore, Lux, GNNLux, Random
+
+rng = Random.default_rng()
+chin = 6
+chout = 5
+
+fgate = Dense(chin, 1)
+ffeat = Dense(chin, chout)
+pool = GlobalAttentionPool(fgate, ffeat)
+
+g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),
+                    ndata=rand(Float32, chin, 10))
+           for i=1:3])
+
+ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat))
+st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat))
+
+u, st = pool(g, g.ndata.x, ps, st)
+
+@assert size(u) == (chout, g.num_graphs)
+```
+"""
+@concrete struct GlobalAttentionPool <: GNNContainerLayer{(:fgate, :ffeat)}
+    fgate
+    ffeat
+end
+
+GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
+
+function (l::GlobalAttentionPool)(g, x, ps, st)
+    fgate = StatefulLuxLayer{true}(l.fgate, ps.fgate, _getstate(st, :fgate))
+    ffeat = StatefulLuxLayer{true}(l.ffeat, ps.ffeat, _getstate(st, :ffeat))
+    m = (; fgate, ffeat)
+    return GNNlib.global_attention_pool(m, g, x), st
+end
+
+(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
+
+"""
+    TopKPool(adj, k, in_channel)
+
+Top-k pooling layer.
+
+# Arguments
+
+- `adj`: Adjacency matrix of a graph.
+- `k`: Top-k nodes are selected to pool together.
+- `in_channel`: The dimension of input channel.
+"""
+struct TopKPool{T, S}
+    A::AbstractMatrix{T}
+    k::Int
+    p::AbstractVector{S}
+    Ã::AbstractMatrix{T}
+end
+
+function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform)
+    TopKPool(adj, k, init(in_channel), similar(adj, k, k))
+end
+
+(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x), st
````
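For readers skimming the diff, the attention pooling that `GNNlib.global_attention_pool` performs can be sketched directly with GNNlib's node utilities. This is an illustrative reconstruction, not code from this commit: the plain closures stand in for the `fgate`/`ffeat` networks, and it assumes `softmax_nodes` and `reduce_nodes` behave as documented.

```julia
# Illustrative sketch (not part of this commit): attention pooling
# built from GNNlib's documented node utilities.
using GNNlib, GNNGraphs, Graphs

g = batch([GNNGraph(random_regular_graph(10, 4),
                    ndata = rand(Float32, 6, 10)) for _ in 1:3])
x = g.ndata.x

fgate(x) = sum(x, dims = 1)   # stand-in for a trained Dense(6, 1) gate
ffeat(x) = x                  # stand-in for the feature network

α = softmax_nodes(g, fgate(x))         # per-graph softmax of gate scores
u = reduce_nodes(+, g, α .* ffeat(x))  # weighted sum of features per graph

@assert size(u) == (6, g.num_graphs)
```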

GNNLux/test/layers/pool.jl

Lines changed: 37 additions & 6 deletions
```diff
@@ -1,15 +1,46 @@
 @testitem "Pooling" setup=[TestModuleLux] begin
     using .TestModuleLux
-    @testset "GlobalPool" begin
+    @testset "Pooling" begin

         rng = StableRNG(1234)
-        g = rand_graph(rng, 10, 40)
-        in_dims = 3
-        x = randn(rng, Float32, in_dims, 10)
-
-        @testset "GCNConv" begin
+        @testset "GlobalPool" begin
+            g = rand_graph(rng, 10, 40)
+            in_dims = 3
+            x = randn(rng, Float32, in_dims, 10)
             l = GlobalPool(mean)
             test_lux_layer(rng, l, g, x, sizey=(in_dims,1))
         end
+        @testset "GlobalAttentionPool" begin
+            n = 10
+            chin = 6
+            chout = 5
+            ng = 3
+            g = batch([GNNGraph(rand_graph(rng, 10, 40),
+                                ndata = rand(Float32, chin, n)) for i in 1:ng])
+
+            fgate = Dense(chin, 1)
+            ffeat = Dense(chin, chout)
+            l = GlobalAttentionPool(fgate, ffeat)
+
+            test_lux_layer(rng, l, g, g.ndata.x, sizey=(chout,ng), container=true)
+        end
+
+        @testset "TopKPool" begin
+            N = 10
+            k, in_channel = 4, 7
+            X = rand(in_channel, N)
+            ps = (;)
+            st = (;)
+            for T in [Bool, Float64]
+                adj = rand(T, N, N)
+                p = GNNLux.TopKPool(adj, k, in_channel)
+                @test eltype(p.p) === Float32
+                @test size(p.p) == (in_channel,)
+                @test eltype(p.Ã) === T
+                @test size(p.Ã) == (k, k)
+                y, st = p(X, ps, st)
+                @test size(y) == (in_channel, k)
+            end
+        end
     end
 end
```
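The `TopKPool` fields exercised in these tests follow the top-k pooling scheme of [Graph U-Nets](https://arxiv.org/abs/1905.05178): nodes are scored by projecting features onto the learnable vector `p`, the `k` highest-scoring nodes are kept, and `Ã` holds the induced adjacency. `GNNlib.topk_pool`'s exact implementation is not shown in this diff; the sketch below is a generic rendition of the technique, with the `tanh` gating taken from the paper rather than from this code.

```julia
# Generic sketch of top-k pooling (assumed semantics; not GNNlib's
# exact implementation). Mirrors the fields tested above: A, k, p, Ã.
using LinearAlgebra

function topk_pool_sketch(A, k, p, X)
    y = vec(p' * X) ./ norm(p)                 # one score per node
    idx = partialsortperm(y, 1:k, rev = true)  # indices of the k best nodes
    Ã = A[idx, idx]                            # adjacency of the pooled graph
    return (X[:, idx] .* tanh.(y[idx])', Ã)    # gated pooled features
end

X = rand(Float32, 7, 10)
A = rand(Bool, 10, 10)
p = rand(Float32, 7)
y, Ã = topk_pool_sketch(A, 4, p, X)
@assert size(y) == (7, 4) && size(Ã) == (4, 4)
```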

GNNlib/src/layers/pool.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
 function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
     n_in = size(x, 1)
     qstar = zeros_like(x, (2*n_in, g.num_graphs))
-    h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
-    c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
+    h = zeros_like(l.Wh, size(l.Wh, 2))
+    c = zeros_like(l.Wh, size(l.Wh, 2))
     state = (h, c)
     for t in 1:l.num_iters
         q, state = l.lstm(qstar, state) # [n_in, n_graphs]
```

GraphNeuralNetworks/src/layers/pool.jl

Lines changed: 4 additions & 3 deletions
```diff
@@ -61,10 +61,10 @@ operation:
 # Arguments

 - `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
-  It is tipically expressed by a neural network.
+  It is typically expressed by a neural network.

 - `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
-  It is tipically expressed by a neural network.
+  It is typically expressed by a neural network.

 # Examples

@@ -156,7 +156,8 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
 end

 function (l::Set2Set)(g, x)
-    return GNNlib.set2set_pool(l, g, x)
+    m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh)
+    return GNNlib.set2set_pool(m, g, x)
 end

 (l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
```
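The two hunks above are coupled: `GNNlib.set2set_pool` now reads `l.Wh` directly, so the Flux-side `Set2Set` wrapper forwards that field via a named tuple. A minimal sketch of this adapter pattern, with hypothetical stand-in types (only the field names match the diff):

```julia
# Minimal sketch of the named-tuple adapter used above.
# FakeLSTM/FakeSet2Set are stand-ins, not real library types.
struct FakeLSTM
    Wh::Matrix{Float32}
end

struct FakeSet2Set
    lstm::FakeLSTM
    num_iters::Int
end

l = FakeSet2Set(FakeLSTM(zeros(Float32, 8, 4)), 3)

# `(; l.lstm, l.num_iters, Wh = l.lstm.Wh)` builds a named tuple whose
# `Wh` field satisfies the new `l.Wh` access inside `set2set_pool`.
m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh)
@assert m.Wh === l.lstm.Wh && m.num_iters == 3
```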
