Skip to content

Commit 93a2d53

Browse files
committed
Add GlobalPool
1 parent 7e07564 commit 93a2d53

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,8 @@ export TGCN,
4848
DCGRU,
4949
EvolveGCNO
5050

51+
include("layers/pool.jl")
52+
export GlobalPool
53+
5154
end #module
5255

GNNLux/src/layers/pool.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
@doc raw"""
2+
GlobalPool(aggr)
3+
4+
Global pooling layer for graph neural networks.
5+
Takes a graph and feature nodes as inputs
6+
and performs the operation
7+
8+
```math
9+
\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
10+
```
11+
12+
where ``V`` is the set of nodes of the input graph and
13+
the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
14+
Commonly used aggregations are `mean`, `max`, and `+`.
15+
16+
See also [`reduce_nodes`](@ref).
17+
18+
# Examples
19+
20+
```julia
21+
using Lux, GNNLux, Graphs, MLUtils
22+
23+
using Graphs
24+
pool = GlobalPool(mean)
25+
26+
g = GNNGraph(erdos_renyi(10, 4))
27+
X = rand(32, 10)
28+
pool(g, X) # => 32x1 matrix
29+
30+
31+
g = MLUtils.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5])
32+
X = rand(32, 50)
33+
pool(g, X) # => 32x5 matrix
34+
```
35+
"""
36+
struct GlobalPool{F} <: GNNLayer
37+
aggr::F
38+
end
39+
40+
(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st
41+
42+
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))

0 commit comments

Comments
 (0)