Skip to content

Commit 321dee7

Browse files
committed
Add TopK pooling
1 parent bf94758 commit 321dee7

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

GNNLux/src/layers/pool.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,66 @@ function (l::GlobalAttentionPool)(g, x, ps, st)
106106
return GNNlib.global_attention_pool(m, g, x), st
107107
end
108108

109-
(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
109+
(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
110+
111+
"""
112+
TopKPool(adj, k, in_channel)
113+
114+
Top-k pooling layer.
115+
116+
# Arguments
117+
118+
- `adj`: Adjacency matrix of a graph.
119+
- `k`: Top-k nodes are selected to pool together.
120+
- `in_channel`: The dimension of input channel.
121+
"""
122+
struct TopKPool{T, S}
123+
A::AbstractMatrix{T}
124+
k::Int
125+
p::AbstractVector{S}
126+
::AbstractMatrix{T}
127+
end
128+
129+
function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform)
130+
TopKPool(adj, k, init(in_channel), similar(adj, k, k))
131+
end
132+
133+
(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x)
134+
135+
136+
@doc raw"""
137+
Set2Set(n_in, n_iters, n_layers = 1)
138+
139+
Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391).
140+
141+
For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times:
142+
```math
143+
\mathbf{q} = \mathrm{LSTM}(\mathbf{q}_{t-1}^*)
144+
\alpha_{i} = \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)}
145+
\mathbf{r} = \sum_{i=1}^N \alpha_{i} \mathbf{x}_i
146+
\mathbf{q}^*_t = [\mathbf{q}; \mathbf{r}]
147+
```
148+
where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers, input size `2*n_in` and output size `n_in`.
149+
150+
Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`.
151+
```
152+
"""
153+
struct Set2Set{L} <: GNNContainerLayer{(:lstm,)}
154+
lstm::L
155+
num_iters::Int
156+
end
157+
158+
function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
159+
@assert n_layers == 1 "multiple layers not implemented yet" #TODO
160+
n_out = 2 * n_in
161+
lstm = Lux.LSTMCell(n_out => n_in)
162+
return Set2Set(lstm, n_iters)
163+
end
164+
165+
function (l::Set2Set)(g, x, ps, st)
166+
lstm = StatefulLuxLayer{true}(l.lstm, ps.lstm, _getstate(st, :lstm))
167+
m = (; lstm, Wh = ps.lstm.weight_hh)
168+
return GNNlib.set2set_pool(m, g, x)
169+
end
170+
171+
(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

GNNLux/test/layers/pool.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,23 @@
2424

2525
test_lux_layer(rng, l, g, g.ndata.x, sizey=(chout,ng), container=true)
2626
end
27+
28+
@testset "TopKPool" begin
29+
N = 10
30+
k, in_channel = 4, 7
31+
X = rand(in_channel, N)
32+
ps = (;)
33+
st = (;)
34+
for T in [Bool, Float64]
35+
adj = rand(T, N, N)
36+
p = GNNLux.TopKPool(adj, k, in_channel)
37+
@test eltype(p.p) === Float32
38+
@test size(p.p) == (in_channel,)
39+
@test eltype(p.Ã) === T
40+
@test size(p.Ã) == (k, k)
41+
y = p(X, ps, st)
42+
@test size(y) == (in_channel, k)
43+
end
44+
end
2745
end
2846
end

0 commit comments

Comments
 (0)