
Commit bfa25ce

fixes
1 parent c473c9e commit bfa25ce

4 files changed, +70 -31 lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b"

 [compat]
 Adapt = "3"
@@ -33,6 +34,7 @@ MacroTools = "0.5"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
 julia = "1.6"
+TestEnv = "1"

 [extras]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
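
TestEnv is added both as a dependency and to `[compat]`. Presumably this is to make the package's test environment easy to activate from a development REPL; the motivation is an assumption, but a minimal sketch of that workflow with TestEnv.jl would look like:

```julia
# Hypothetical workflow this dependency enables (not part of the commit):
# activate the test environment so test-only packages become loadable,
# then run the suite interactively.
using TestEnv
TestEnv.activate()            # pulls in the [extras]/test dependencies of the active project
include("test/runtests.jl")
```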

src/layers/pool.jl

Lines changed: 36 additions & 14 deletions
@@ -10,13 +10,15 @@ and performs the operation
 ```math
 \mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
 ```
+
 where ``V`` is the set of nodes of the input graph and
 the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
 Commonly used aggregations are `mean`, `max`, and `+`.

 See also [`reduce_nodes`](@ref).

 # Examples
+
 ```julia
 using Flux, GraphNeuralNetworks, Graphs

@@ -50,18 +52,42 @@ Global soft attention layer from the [Gated Graph Sequence Neural
 Networks](https://arxiv.org/abs/1511.05493) paper

 ```math
-\mathbf{u}_V} = \sum_{i\in V} \mathrm{softmax} \left(
-f_{\mathrm{gate}} ( \mathbf{x}_i ) \right) \odot
-f_{\mathrm{feat}} ( \mathbf{x}_i ),
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{\mathrm{feat}}(\mathbf{x}_i)
 ```

-where ``f_{\mathrm{gate}} \colon \mathbb{R}^F \to
-\mathbb{R}`` and ``f_{\mathbf{feat}}` denote neural networks.
+where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
+operation:
+
+```math
+\alpha_i = \frac{e^{f_{\mathrm{gate}}(\mathbf{x}_i)}}
+          {\sum_{i'\in V} e^{f_{\mathrm{gate}}(\mathbf{x}_{i'})}}.
+```

 # Arguments

-fgate:
-ffeat:
+- `fgate`: The function ``f_{\mathrm{gate}} \colon \mathbb{R}^{D_{in}} \to
+  \mathbb{R}``. It is typically a neural network.
+
+- `ffeat`: The function ``f_{\mathrm{feat}} \colon \mathbb{R}^{D_{in}} \to
+  \mathbb{R}^{D_{out}}``. It is typically a neural network.
+
+# Examples
+
+```julia
+chin = 6
+chout = 5
+
+fgate = Dense(chin, 1)
+ffeat = Dense(chin, chout)
+pool = GlobalAttentionPool(fgate, ffeat)
+
+g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
+                         ndata=rand(Float32, chin, 10))
+                for i=1:3])
+
+u = pool(g, g.ndata.x)
+
+@assert size(u) == (chout, g.num_graphs)
+```
 """
 struct GlobalAttentionPool{G,F}
     fgate::G
@@ -72,11 +98,10 @@ end

 GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

-
 function (l::GlobalAttentionPool)(g::GNNGraph, x::AbstractArray)
-    weights = softmax_nodes(g, l.fgate(x))
-    feats = l.ffeat(x)
-    u = reduce_nodes(+, g, weights .* feats)
+    α = softmax_nodes(g, l.fgate(x))
+    feats = α .* l.ffeat(x)
+    u = reduce_nodes(+, g, feats)
     return u
 end

@@ -101,9 +126,6 @@ struct TopKPool{T,S}
     Ã::AbstractMatrix{T}
 end

-
-
-
 function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init=glorot_uniform)
     TopKPool(adj, k, init(in_channel), similar(adj, k, k))
 end
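
Returning to the `GlobalAttentionPool` forward pass refactored above: the following standalone sketch (not part of the commit) spells out what the new implementation computes, using the `softmax_nodes` and `reduce_nodes` helpers referenced in the diff; dimensions follow the docstring example.

```julia
using Flux, GraphNeuralNetworks, Graphs

chin, chout = 6, 5
fgate = Dense(chin, 1)      # scalar gate score per node
ffeat = Dense(chin, chout)  # per-node feature transform

g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
                         ndata=rand(Float32, chin, 10)) for _ in 1:3])
x = g.ndata.x               # chin × num_nodes

α = softmax_nodes(g, fgate(x))         # 1 × num_nodes, normalized within each graph
u = reduce_nodes(+, g, α .* ffeat(x))  # chout × num_graphs: attention-weighted sum per graph

@assert size(u) == (chout, g.num_graphs)
@assert u ≈ GlobalAttentionPool(fgate, ffeat)(g, x)
```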

test/layers/pool.jl

Lines changed: 18 additions & 9 deletions
@@ -4,13 +4,14 @@
     n = 10
     chin = 6
     X = rand(Float32, 6, n)
-    g = GNNGraph(random_regular_graph(n, 4), ndata=X)
+    g = GNNGraph(random_regular_graph(n, 4), ndata=X, graph_type=GRAPH_T)
     u = p(g, X)
     @test u ≈ sum(X, dims=2)

     ng = 3
     g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
-                             ndata=rand(Float32, chin, n))
+                             ndata=rand(Float32, chin, n),
+                             graph_type=GRAPH_T)
                     for i=1:ng])
     u = p(g, g.ndata.x)
     @test size(u) == (chin, ng)
@@ -22,13 +23,21 @@

 @testset "GlobalAttentionPool" begin
     n = 10
-    chin = 16
-    X = rand(Float32, chin, n)
-    g = GNNGraph(random_regular_graph(n, 4), ndata=X)
-    fgate = Dense(chin, 1, sigmoid)
-    p = GlobalAttentionPool(fgate)
-    y = p(g, X)
-    test_layer(p, g, rtol=1e-5, outtype=:graph)
+    chin = 6
+    chout = 5
+    ng = 3
+
+    fgate = Dense(chin, 1)
+    ffeat = Dense(chin, chout)
+    p = GlobalAttentionPool(fgate, ffeat)
+    @test length(Flux.params(p)) == 4
+
+    g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
+                             ndata=rand(Float32, chin, n),
+                             graph_type=GRAPH_T)
+                    for i=1:ng])
+
+    test_layer(p, g, rtol=1e-5, outtype=:graph, outsize=(chout, ng))
 end
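
A side note on the new `@test length(Flux.params(p)) == 4` assertion above: the count follows from the two `Dense` layers, each contributing a weight matrix and a bias vector. A quick sketch with plain Flux, using the same layer sizes as the test:

```julia
using Flux

fgate = Dense(6, 1)
ffeat = Dense(6, 5)

length(Flux.params(fgate))  # 2: weight and bias
length(Flux.params(ffeat))  # 2: weight and bias
# GlobalAttentionPool(fgate, ffeat) therefore exposes 4 trainable arrays.
```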

3443

test/test_utils.jl

Lines changed: 14 additions & 8 deletions
@@ -34,14 +34,21 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
     x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
     xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])

-    f(l, g) = l(g)
-    f(l, g, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
-    f(l, g, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
-    f(l, g, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)
+    f(l, g::GNNGraph) = l(g)
+    f(l, g::GNNGraph, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
+    f(l, g::GNNGraph, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
+    f(l, g::GNNGraph, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)

-    loss(l, g) = sum(node_features(f(l, g)))
-    loss(l, g, x) = sum(f(l, g, x))
-    loss(l, g, x, e) = sum(l(g, x, e))
+    loss(l, g::GNNGraph) = if outtype == :node
+                               sum(node_features(f(l, g)))
+                           elseif outtype == :edge
+                               sum(edge_features(f(l, g)))
+                           elseif outtype == :graph
+                               sum(graph_features(f(l, g)))
+                           end
+
+    loss(l, g::GNNGraph, x) = sum(f(l, g, x))
+    loss(l, g::GNNGraph, x, e) = sum(l(g, x, e))


     # TEST OUTPUT
@@ -117,7 +124,6 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,

     # TEST LAYER GRADIENT - l(g)
     l̄ = gradient(l -> loss(l, g), l)[1]
-    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1]
     test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)

     return true
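
The reworked `loss` above assumes the output of `l(g)` can be read with one of three feature accessors, selected by `outtype`. A minimal sketch (not part of the commit) of the corresponding containers on a `GNNGraph`, assuming the standard `ndata`/`edata`/`gdata` constructor keywords; the feature sizes are illustrative only:

```julia
using GraphNeuralNetworks, Graphs

g = GNNGraph(random_regular_graph(10, 4);
             ndata=rand(Float32, 6, 10),   # node-level features
             edata=rand(Float32, 2, 40),   # edge-level features (20 undirected edges -> 40 directed)
             gdata=rand(Float32, 5, 1))    # graph-level features

size(node_features(g))   # (6, 10)  summed when outtype == :node
size(edge_features(g))   # (2, 40)  summed when outtype == :edge
size(graph_features(g))  # (5, 1)   summed when outtype == :graph
```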
