
Commit e0a5214

Merge pull request #58 from CarloLucibello/cl/ga
add GlobalAttentionPool
2 parents 205cc0d + bfa25ce commit e0a5214

File tree: 5 files changed, +120 -13 lines changed


Project.toml

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b"

 [compat]
 Adapt = "3"
@@ -33,6 +34,7 @@ MacroTools = "0.5"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
 julia = "1.6"
+TestEnv = "1"

 [extras]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
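TestEnv is added to the dependencies so the package's test environment can be activated from an interactive session. A minimal sketch of a typical workflow from the package root (assumed usage, not part of this commit):

```julia
using TestEnv
TestEnv.activate()            # activate a temporary environment with the test dependencies
using GraphNeuralNetworks
include("test/runtests.jl")   # run the test suite interactively
```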

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ export

   # layers/pool
   GlobalPool,
+  GlobalAttentionPool,
   TopKPool,
   topk_index


src/layers/pool.jl

Lines changed: 66 additions & 0 deletions
@@ -10,13 +10,15 @@ and performs the operation
 ```math
 \mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
 ```
+
 where ``V`` is the set of nodes of the input graph and
 the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
 Commonly used aggregations are `mean`, `max`, and `+`.

 See also [`reduce_nodes`](@ref).

 # Examples
+
 ```julia
 using Flux, GraphNeuralNetworks, Graphs
@@ -42,6 +44,70 @@ end

 (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))

+
+@doc raw"""
+    GlobalAttentionPool(fgate, ffeat=identity)
+
+Global soft attention layer from the [Gated Graph Sequence Neural
+Networks](https://arxiv.org/abs/1511.05493) paper
+
+```math
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{\mathrm{feat}}(\mathbf{x}_i)
+```
+
+where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
+operation:
+
+```math
+\alpha_i = \frac{e^{f_{\mathrm{gate}}(\mathbf{x}_i)}}
+                {\sum_{i'\in V} e^{f_{\mathrm{gate}}(\mathbf{x}_{i'})}}.
+```
+
+# Arguments
+
+- `fgate`: The function ``f_{\mathrm{gate}} \colon \mathbb{R}^{D_{in}} \to
+  \mathbb{R}``. It is typically a neural network.
+
+- `ffeat`: The function ``f_{\mathrm{feat}} \colon \mathbb{R}^{D_{in}} \to
+  \mathbb{R}^{D_{out}}``. It is typically a neural network.
+
+# Examples
+
+```julia
+chin = 6
+chout = 5
+
+fgate = Dense(chin, 1)
+ffeat = Dense(chin, chout)
+pool = GlobalAttentionPool(fgate, ffeat)
+
+g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
+                         ndata=rand(Float32, chin, 10))
+                for i=1:3])
+
+u = pool(g, g.ndata.x)
+
+@assert size(u) == (chout, g.num_graphs)
+```
+"""
+struct GlobalAttentionPool{G,F}
+    fgate::G
+    ffeat::F
+end
+
+@functor GlobalAttentionPool
+
+GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
+
+function (l::GlobalAttentionPool)(g::GNNGraph, x::AbstractArray)
+    α = softmax_nodes(g, l.fgate(x))
+    feats = α .* l.ffeat(x)
+    u = reduce_nodes(+, g, feats)
+    return u
+end
+
+(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))
+
 """
     TopKPool(adj, k, in_channel)
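To make the docstring's formulas concrete: for a single graph, `softmax_nodes` reduces to an ordinary softmax over the graph's nodes and `reduce_nodes(+, ...)` to a sum over nodes. A minimal plain-Julia sketch of the same computation (an illustration under those assumptions, not library code), with `x` a `chin × n` node-feature matrix:

```julia
using Flux

# Plain-Julia equivalent of the GlobalAttentionPool forward pass on a single graph.
# `fgate` assigns each node a scalar score; `ffeat` maps node features to the output space.
function attention_pool(fgate, ffeat, x::AbstractMatrix)
    scores = fgate(x)                       # 1 × n gate scores
    α = exp.(scores) ./ sum(exp.(scores))   # softmax over the graph's nodes
    return sum(α .* ffeat(x), dims=2)       # u = Σ_i α_i f_feat(x_i)
end

fgate = Dense(6, 1)
ffeat = Dense(6, 5)
x = rand(Float32, 6, 10)                    # 10 nodes, 6 features each
u = attention_pool(fgate, ffeat, x)         # 5 × 1 graph-level representation
```

On a batched `GNNGraph` the layer performs the same softmax and sum independently for each graph in the batch.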

test/layers/pool.jl

Lines changed: 37 additions & 5 deletions
@@ -1,14 +1,46 @@
 @testset "pool" begin
     @testset "GlobalPool" begin
-        n = 10
-        X = rand(Float32, 16, n)
-        g = GNNGraph(random_regular_graph(n, 4), ndata=X)
         p = GlobalPool(+)
-        y = p(g, X)
-        @test y ≈ NNlib.scatter(+, X, ones(Int, n))
+        n = 10
+        chin = 6
+        X = rand(Float32, 6, n)
+        g = GNNGraph(random_regular_graph(n, 4), ndata=X, graph_type=GRAPH_T)
+        u = p(g, X)
+        @test u ≈ sum(X, dims=2)
+
+        ng = 3
+        g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
+                                 ndata=rand(Float32, chin, n),
+                                 graph_type=GRAPH_T)
+                        for i=1:ng])
+        u = p(g, g.ndata.x)
+        @test size(u) == (chin, ng)
+        @test u[:,[1]] ≈ sum(g.ndata.x[:,1:n], dims=2)
+        @test p(g).gdata.u == u
+
         test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph)
     end

+    @testset "GlobalAttentionPool" begin
+        n = 10
+        chin = 6
+        chout = 5
+        ng = 3
+
+        fgate = Dense(chin, 1)
+        ffeat = Dense(chin, chout)
+        p = GlobalAttentionPool(fgate, ffeat)
+        @test length(Flux.params(p)) == 4
+
+        g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
+                                 ndata=rand(Float32, chin, n),
+                                 graph_type=GRAPH_T)
+                        for i=1:ng])
+
+        test_layer(p, g, rtol=1e-5, outtype=:graph, outsize=(chout, ng))
+    end
+
+
     @testset "TopKPool" begin
         N = 10
         k, in_channel = 4, 7

test/test_utils.jl

Lines changed: 14 additions & 8 deletions
@@ -34,14 +34,21 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
     x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
     xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])

-    f(l, g) = l(g)
-    f(l, g, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
-    f(l, g, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
-    f(l, g, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)
+    f(l, g::GNNGraph) = l(g)
+    f(l, g::GNNGraph, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
+    f(l, g::GNNGraph, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
+    f(l, g::GNNGraph, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)

-    loss(l, g) = sum(node_features(f(l, g)))
-    loss(l, g, x) = sum(f(l, g, x))
-    loss(l, g, x, e) = sum(l(g, x, e))
+    loss(l, g::GNNGraph) = if outtype == :node
+            sum(node_features(f(l, g)))
+        elseif outtype == :edge
+            sum(edge_features(f(l, g)))
+        elseif outtype == :graph
+            sum(graph_features(f(l, g)))
+        end
+
+    loss(l, g::GNNGraph, x) = sum(f(l, g, x))
+    loss(l, g::GNNGraph, x, e) = sum(l(g, x, e))


     # TEST OUTPUT
@@ -117,7 +124,6 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,

     # TEST LAYER GRADIENT - l(g)
     l̄ = gradient(l -> loss(l, g), l)[1]
-    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1]
     test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)

     return true
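The new `outtype == :graph` branch is what lets `test_layer` exercise graph-level layers such as `GlobalAttentionPool`, whose `l(g)` method returns a graph with the pooled output stored in `gdata`. A rough standalone sketch of the quantity that branch reduces (assumed usage of the exported accessors, not part of the test suite):

```julia
using Flux, GraphNeuralNetworks, Graphs

pool = GlobalAttentionPool(Dense(6, 1), Dense(6, 5))
g = GNNGraph(random_regular_graph(10, 4), ndata=rand(Float32, 6, 10))

g_out = pool(g)                    # GNNGraph carrying the pooled features in gdata
loss = sum(graph_features(g_out))  # what loss(l, g) computes when outtype == :graph
```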
