Commit c473c9e (parent 205cc0d)

add GlobalAttentionPool

3 files changed: +72 −4 lines changed

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ export
   # layers/pool
   GlobalPool,
+  GlobalAttentionPool,
   TopKPool,
   topk_index
src/layers/pool.jl

Lines changed: 44 additions & 0 deletions
@@ -42,6 +42,47 @@ end

 (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))

+@doc raw"""
+    GlobalAttentionPool(fgate, ffeat=identity)
+
+Global soft attention layer from the [Gated Graph Sequence Neural
+Networks](https://arxiv.org/abs/1511.05493) paper
+
+```math
+\mathbf{u}_V = \sum_{i \in V} \mathrm{softmax} \left(
+    f_{\mathrm{gate}} ( \mathbf{x}_i ) \right) \odot
+    f_{\mathrm{feat}} ( \mathbf{x}_i ),
+```
+
+where ``f_{\mathrm{gate}} \colon \mathbb{R}^F \to \mathbb{R}``
+and ``f_{\mathrm{feat}}`` denote neural networks.
+
+# Arguments
+
+- `fgate`: The function mapping each node's feature vector to a scalar gating score.
+- `ffeat`: The function applied to the node features before the gated sum. Default `identity`.
+"""
+struct GlobalAttentionPool{G,F}
+    fgate::G
+    ffeat::F
+end
+
+@functor GlobalAttentionPool
+
+GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
+
+function (l::GlobalAttentionPool)(g::GNNGraph, x::AbstractArray)
+    weights = softmax_nodes(g, l.fgate(x))
+    feats = l.ffeat(x)
+    u = reduce_nodes(+, g, weights .* feats)
+    return u
+end
+
+(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))
+
 """
     TopKPool(adj, k, in_channel)

@@ -60,6 +101,9 @@ struct TopKPool{T,S}
     Ã::AbstractMatrix{T}
 end

+
+
+
 function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init=glorot_uniform)
     TopKPool(adj, k, init(in_channel), similar(adj, k, k))
 end
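
As a quick illustration of the new layer, here is a minimal usage sketch (not part of the commit): it assumes GraphNeuralNetworks.jl, Flux, and Graphs are loaded, and the output size of 32 is an arbitrary choice.

```julia
using Flux, Graphs, GraphNeuralNetworks

# Small random graph with 16-dimensional node features.
n, chin = 10, 16
g = GNNGraph(random_regular_graph(n, 4), ndata = rand(Float32, chin, n))

# fgate scores each node with a scalar; ffeat projects the features
# that enter the softmax-weighted sum.
pool = GlobalAttentionPool(Dense(chin, 1), Dense(chin, 32))

u = pool(g, node_features(g))   # one 32-dimensional vector per graph
@assert size(u) == (32, 1)
```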

test/layers/pool.jl

Lines changed: 27 additions & 4 deletions
@@ -1,14 +1,37 @@
 @testset "pool" begin
     @testset "GlobalPool" begin
+        p = GlobalPool(+)
         n = 10
-        X = rand(Float32, 16, n)
+        chin = 6
+        X = rand(Float32, 6, n)
         g = GNNGraph(random_regular_graph(n, 4), ndata=X)
-        p = GlobalPool(+)
-        y = p(g, X)
-        @test y ≈ NNlib.scatter(+, X, ones(Int, n))
+        u = p(g, X)
+        @test u ≈ sum(X, dims=2)
+
+        ng = 3
+        g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
+                                 ndata=rand(Float32, chin, n))
+                        for i=1:ng])
+        u = p(g, g.ndata.x)
+        @test size(u) == (chin, ng)
+        @test u[:,[1]] ≈ sum(g.ndata.x[:,1:n], dims=2)
+        @test p(g).gdata.u == u
+
         test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph)
     end

+    @testset "GlobalAttentionPool" begin
+        n = 10
+        chin = 16
+        X = rand(Float32, chin, n)
+        g = GNNGraph(random_regular_graph(n, 4), ndata=X)
+        fgate = Dense(chin, 1, sigmoid)
+        p = GlobalAttentionPool(fgate)
+        y = p(g, X)
+        test_layer(p, g, rtol=1e-5, outtype=:graph)
+    end
+
     @testset "TopKPool" begin
         N = 10
         k, in_channel = 4, 7