
Commit a173306

Add GlobalAttentionPool

1 parent 24eae1f commit a173306

4 files changed, +81 -7 lines changed

GNNLux/src/GNNLux.jl

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ export TGCN,
        EvolveGCNO

 include("layers/pool.jl")
-export GlobalPool
+export GlobalPool,
+       GlobalAttentionPool

 end #module


GNNLux/src/layers/pool.jl

Lines changed: 68 additions & 1 deletion
@@ -39,4 +39,71 @@ end

 (l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

-(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
+(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
+
+@doc raw"""
+    GlobalAttentionPool(fgate, ffeat=identity)
+
+Global soft attention layer from the [Gated Graph Sequence Neural
+Networks](https://arxiv.org/abs/1511.05493) paper
+
+```math
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
+```
+
+where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
+operation:
+
+```math
+\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
+                {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
+```
+
+# Arguments
+
+- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
+  It is typically expressed by a neural network.
+
+- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
+  It is typically expressed by a neural network.
+
+# Examples
+
+```julia
+using Graphs, LuxCore, Lux, GNNLux, Random
+
+rng = Random.default_rng()
+chin = 6
+chout = 5
+
+fgate = Dense(chin, 1)
+ffeat = Dense(chin, chout)
+pool = GlobalAttentionPool(fgate, ffeat)
+
+g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),
+                    ndata=rand(Float32, chin, 10))
+           for i=1:3])
+
+ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat))
+st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat))
+
+u, st = pool(g, g.ndata.x, ps, st)
+
+@assert size(u) == (chout, g.num_graphs)
+```
+"""
+struct GlobalAttentionPool{G, F}
+    fgate::G
+    ffeat::F
+end
+
+GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
+
+function (l::GlobalAttentionPool)(g, x, ps, st)
+    fgate = StatefulLuxLayer{true}(l.fgate, ps.fgate, _getstate(st, :fgate))
+    ffeat = StatefulLuxLayer{true}(l.ffeat, ps.ffeat, _getstate(st, :ffeat))
+    m = (; fgate, ffeat)
+    return GNNlib.global_attention_pool(m, g, x), st
+end
+
+(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
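
For intuition, the attention formula in the docstring above can be written out directly. The sketch below is a hand-rolled illustration for a single (non-batched) graph, assuming plain Julia functions for the gate and feature maps; the name `attention_pool_sketch` is hypothetical, and this is not the `GNNlib.global_attention_pool` implementation, which additionally restricts the softmax to each graph in a batch via `softmax_nodes`.

```julia
# Minimal sketch of soft attention pooling on one graph:
#   u_V = Σ_i α_i · ffeat(x_i),  α_i = softmax over nodes of fgate(x_i)
function attention_pool_sketch(fgate, ffeat, x::AbstractMatrix)
    gates = fgate(x)                       # 1 × num_nodes gate scores
    α = exp.(gates) ./ sum(exp.(gates))    # softmax over this graph's nodes
    h = ffeat(x)                           # D_out × num_nodes transformed features
    return h * vec(α)                      # weighted sum → length-D_out graph embedding
end

# Toy stand-ins for the gate / feature networks (not Lux layers):
x = rand(Float32, 6, 10)                   # 6 features, 10 nodes
fgate(x) = sum(x, dims = 1)                # 1 × num_nodes gate
ffeat(x) = x                               # identity feature map
u = attention_pool_sketch(fgate, ffeat, x) # 6-element vector
```

With the `Dense(chin, chout)` feature map from the docstring example, the result would instead have length `chout`, matching the `@assert` in the example.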

GNNLux/test/layers/pool.jl

Lines changed: 9 additions & 3 deletions
@@ -1,15 +1,21 @@
 @testitem "Pooling" setup=[TestModuleLux] begin
     using .TestModuleLux
-    @testset "GlobalPool" begin
+    @testset "Pooling" begin

         rng = StableRNG(1234)
         g = rand_graph(rng, 10, 40)
         in_dims = 3
         x = randn(rng, Float32, in_dims, 10)

-        @testset "GCNConv" begin
-            l = GlobalPool(mean)
+        @testset "GlobalPool" begin
+            l = GNNLux.GlobalPool(mean)
             test_lux_layer(rng, l, g, x, sizey=(in_dims,1))
         end
+        @testset "GlobalAttentionPool" begin
+            fgate = Dense(in_dims, 1)
+            ffeat = Dense(in_dims, in_dims)
+            l = GNNLux.GlobalAttentionPool(fgate, ffeat)
+            test_lux_layer(rng, l, g, x, sizey=(in_dims,1), container=true)
+        end
     end
 end

GraphNeuralNetworks/src/layers/pool.jl

Lines changed: 2 additions & 2 deletions
@@ -61,10 +61,10 @@ operation:
 # Arguments

 - `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
-  It is tipically expressed by a neural network.
+  It is typically expressed by a neural network.

 - `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
-  It is tipically expressed by a neural network.
+  It is typically expressed by a neural network.

 # Examples

