Skip to content

Commit 28b3645

Browse files
add Set2Set (#276)
* add Set2Set * cleanup * cleanup
1 parent 05fca7c commit 28b3645

File tree

5 files changed

+79
-0
lines changed

5 files changed

+79
-0
lines changed

src/GNNGraphs/transform.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,4 @@ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci
729729
@non_differentiable negative_sample(x...)
730730
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
731731
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
732+
@non_differentiable dense_zeros_like(x...)

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using CUDA
77
using Flux
88
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch
99
using MacroTools: @forward
10+
using MLUtils
1011
using NNlib, NNlibCUDA
1112
using NNlib: scatter, gather
1213
using ChainRulesCore
@@ -69,6 +70,7 @@ export
6970
# layers/pool
7071
GlobalPool,
7172
GlobalAttentionPool,
73+
Set2Set,
7274
TopKPool,
7375
topk_index,
7476

src/layers/pool.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,62 @@ function topk_index(y::AbstractVector, k::Int)
143143
end
144144

145145
topk_index(y::Adjoint, k::Int) = topk_index(y', k)
146+
147+
148+
@doc raw"""
    Set2Set(n_in, n_iters, n_layers = 1)

Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391).

For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times:
```math
\begin{aligned}
\mathbf{q} &= \mathrm{LSTM}(\mathbf{q}_{t-1}^*) \\
\alpha_{i} &= \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)} \\
\mathbf{r} &= \sum_{i=1}^N \alpha_{i} \mathbf{x}_i \\
\mathbf{q}^*_t &= [\mathbf{q}; \mathbf{r}]
\end{aligned}
```
where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers,
input size `2*n_in` and output size `n_in`.

Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`.

# Arguments

- `n_in`: size of the node feature vectors.
- `n_iters`: number of times the update steps above are iterated.
- `n_layers`: number of stacked LSTM layers. Default `1`.
"""
struct Set2Set{L} <: GNNLayer
    # Recurrent network producing the query `q` from the previous `q*`.
    lstm::L
    # Number of times the query/attention/readout steps are repeated per call.
    num_iters::Int
end

@functor Set2Set
172+
173+
# Convenience constructor: builds the LSTM stack for node features of size `n_in`.
#
# The first LSTM maps the concatenated state `q*` (size `2 * n_in`) to a query of
# size `n_in`; any additional layers map `n_in => n_in`. With a single layer the
# bare `LSTM` is stored, otherwise the layers are wrapped in a `Chain`.
#
# Throws an `ArgumentError` if `n_layers < 1`.
function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
    # Explicit check instead of `@assert`: assertions may be disabled at some
    # optimization levels and must not be relied on to validate user input.
    n_layers >= 1 ||
        throw(ArgumentError("n_layers must be at least 1, got $n_layers"))
    n_out = 2 * n_in

    if n_layers == 1
        lstm = LSTM(n_out => n_in)
    else
        layers = [LSTM(n_out => n_in)]
        for _ in 2:n_layers
            push!(layers, LSTM(n_in => n_in))
        end
        lstm = Chain(layers...)
    end

    return Set2Set(lstm, n_iters)
end
189+
190+
# Forward pass on a batched graph `g` with node features `x`.
#
# Resets the LSTM state, starts from an all-zero `q*` and repeats
# query -> per-graph attention -> weighted readout for `l.num_iters` rounds.
# Returns the final `q*`, one column per graph.
function (l::Set2Set)(g::GNNGraph, x::AbstractMatrix)
    Flux.reset!(l.lstm)
    feat_dim = size(x, 1)
    readout = zeros_like(x, (2 * feat_dim, g.num_graphs))
    for _ in 1:(l.num_iters)
        query = l.lstm(readout)                    # [n_in, n_graphs]
        node_query = broadcast_nodes(g, query)     # [n_in, n_nodes]
        scores = sum(node_query .* x; dims = 1)    # [1, n_nodes]
        attn = softmax_nodes(g, scores)            # [1, n_nodes]
        pooled = reduce_nodes(+, g, x .* attn)     # [n_in, n_graphs]
        readout = vcat(query, pooled)              # [2*n_in, n_graphs]
    end
    return readout
end
203+
204+
# Graph-in / graph-out form: runs the layer on the node features of `g` and
# returns a copy of `g` carrying the result as graph-level data.
function (l::Set2Set)(g::GNNGraph)
    pooled = l(g, node_features(g))
    return GNNGraph(g; gdata = pooled)
end

test/layers/pool.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,18 @@ end
6060
@test topk_index(X, 4) == [1, 2, 3, 4]
6161
@test topk_index(X', 4) == [1, 2, 3, 4]
6262
end
63+
64+
@testset "Set2Set" begin
    n_in, n_iters, n_layers = 3, 2, 1

    # Batch of 5 random graphs (10 nodes, 40 edges each) with random node features.
    graphs = [rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5]
    g = batch(graphs)
    g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes))

    l = Set2Set(n_in, n_iters, n_layers)
    y = l(g, node_features(g))

    # One output column of size 2 * n_in per graph in the batch.
    @test size(y) == (2 * n_in, g.num_graphs)

    ## TODO the numerical gradient seems to be 3 times smaller than zygote one
    # test_layer(l, g, rtol = 1e-4, atol=1e-4, outtype = :graph, outsize = (2 * n_in, g.num_graphs),
    #            verbose=true, exclude_grad_fields = [:state0, :state])
end

test/test_utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,10 @@ function test_approx_structs(l, l̄, l̄fd; atol = 1e-5, rtol = 1e-5,
180180

181181
for f in fieldnames(typeof(l))
182182
f ∈ exclude_grad_fields && continue
183+
verbose && println("Test gradient of field $f...")
183184
x, g, gfd = getfield(l, f), getfield(l̄, f), getfield(l̄fd, f)
184185
test_approx_structs(x, g, gfd; atol, rtol, exclude_grad_fields, verbose)
186+
verbose && println("... field $f done!")
185187
end
186188
return true
187189
end

0 commit comments

Comments
 (0)