Skip to content

Commit 36e8373

Browse files
authored
Add GConvLSTM temporal layer (#437)
* Add first draft GConvLSTM * Fix spaces * Add show method * Add `GConvLSTM` export * Fix `GConvLSTM` * Add `GConvLSTM` tests * Add `GCLSTM` docstring * Add temporal feat example * Fix missing end
1 parent 942fe91 commit 36e8373

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ export
7474
# layers/temporalconv
7575
TGCN,
7676
A3TGCN,
77+
GConvLSTM,
7778
GConvGRU,
7879

7980
# layers/pool

src/layers/temporalconv.jl

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,128 @@ Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
279279
_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
280280
_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
281281

282+
# Cell (single time step) of the GConvLSTM recurrent layer.
# Each of the four LSTM gates (input i, forget f, cell c, output o) combines a
# ChebConv applied to the input features (conv_x_*) with a ChebConv applied to
# the hidden state (conv_h_*), plus a peephole weight w_* multiplied
# element-wise with the cell state and an additive bias b_*.
struct GConvLSTMCell <: GNNLayer
    conv_x_i::ChebConv   # input gate: convolution of the input x
    conv_h_i::ChebConv   # input gate: convolution of the hidden state h
    w_i                  # input gate: peephole weight, size (out, 1)
    b_i                  # input gate: bias vector, or `false` when bias is disabled
    conv_x_f::ChebConv   # forget gate: convolution of the input x
    conv_h_f::ChebConv   # forget gate: convolution of the hidden state h
    w_f                  # forget gate: peephole weight
    b_f                  # forget gate: bias (or `false`)
    conv_x_c::ChebConv   # cell-state candidate: convolution of the input x
    conv_h_c::ChebConv   # cell-state candidate: convolution of the hidden state h
    w_c                  # cell-state candidate: peephole weight
    b_c                  # cell-state candidate: bias (or `false`)
    conv_x_o::ChebConv   # output gate: convolution of the input x
    conv_h_o::ChebConv   # output gate: convolution of the hidden state h
    w_o                  # output gate: peephole weight
    b_o                  # output gate: bias (or `false`)
    k::Int               # Chebyshev polynomial order shared by all ChebConv layers
    state0               # initial (h, c) state pair, each of size (out, n)
    in::Int              # number of input features
    out::Int             # number of output features
end

# Register the cell's fields with Flux so they are collected as trainable parameters.
Flux.@functor GConvLSTMCell
306+
307+
"""
    GConvLSTMCell(in => out, k, n; bias = true, init = glorot_uniform, init_state = zeros32)

Construct a `GConvLSTMCell` mapping `in` input features to `out` output
features, using Chebyshev convolutions of order `k` on a graph with `n` nodes.

Each of the four gates (input, forget, cell, output) owns a ChebConv on the
input, a ChebConv on the hidden state, a peephole weight, and optionally a
bias. `init_state` builds the initial hidden and cell states, each of size
`(out, n)`.
"""
function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int;
                       bias::Bool = true,
                       init = Flux.glorot_uniform,
                       init_state = Flux.zeros32)
    in, out = ch
    # Build the parameter set of one gate. The creation order (conv_x, conv_h,
    # w, b) per gate matches the original inline layout, so RNG draws and hence
    # initial weights are identical to the un-factored version.
    function gate_params()
        conv_x = ChebConv(in => out, k; bias, init)
        conv_h = ChebConv(out => out, k; bias, init)
        w = init(out, 1)
        b = bias ? Flux.create_bias(w, true, out) : false
        return conv_x, conv_h, w, b
    end
    conv_x_i, conv_h_i, w_i, b_i = gate_params()  # input gate
    conv_x_f, conv_h_f, w_f, b_f = gate_params()  # forget gate
    conv_x_c, conv_h_c, w_c, b_c = gate_params()  # cell-state candidate
    conv_x_o, conv_h_o, w_o, b_o = gate_params()  # output gate
    # Initial (hidden, cell) state pair, each of size (out, n).
    state0 = (init_state(out, n), init_state(out, n))
    return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i,
                         conv_x_f, conv_h_f, w_f, b_f,
                         conv_x_c, conv_h_c, w_c, b_c,
                         conv_x_o, conv_h_o, w_o, b_o,
                         k, state0, in, out)
end
339+
340+
# One GConvLSTM step: takes the previous state `(h, c)`, the graph `g`, and the
# node features `x`; returns the new state and the new hidden state as output.
# Peephole-LSTM equations with graph convolutions (Seo et al., 2016).
function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x)
    # input gate: i = σ(conv_x_i(x) + conv_h_i(h) + w_i ∘ c_{t-1} + b_i)
    i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i
    i = Flux.sigmoid_fast(i)
    # forget gate: f = σ(conv_x_f(x) + conv_h_f(h) + w_f ∘ c_{t-1} + b_f)
    f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f
    f = Flux.sigmoid_fast(f)
    # cell state: c_t = f ∘ c_{t-1} + i ∘ tanh(candidate); the candidate's
    # peephole term still reads the previous c, since `c` is rebound only here.
    c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c)
    # output gate: its peephole uses the *updated* cell state c_t (statement
    # order is deliberate and must not be changed).
    o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o
    o = Flux.sigmoid_fast(o)
    h = o .* Flux.tanh_fast(c)
    return (h,c), h
end
355+
356+
# Compact one-line printing, e.g. `GConvLSTMCell(2 => 5)`.
Base.show(io::IO, cell::GConvLSTMCell) = print(io, "GConvLSTMCell($(cell.in) => $(cell.out))")
359+
360+
"""
361+
GConvLSTM(in => out, k, n; [bias, init, init_state])
362+
363+
Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
364+
365+
Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
366+
367+
# Arguments
368+
369+
- `in`: Number of input features.
370+
- `out`: Number of output features.
371+
- `k`: Chebyshev polynomial order.
372+
- `n`: Number of nodes in the graph.
373+
- `bias`: Add learnable bias. Default `true`.
374+
- `init`: Weights' initializer. Default `glorot_uniform`.
375+
- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.
376+
377+
# Examples
378+
379+
```jldoctest
380+
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
381+
382+
julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes);
383+
384+
julia> y = gclstm(g1, x1);
385+
386+
julia> size(y)
387+
(5, 5)
388+
389+
julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
390+
391+
julia> z = gclstm(g2, x2);
392+
393+
julia> size(z)
394+
(5, 5, 30)
395+
```
396+
"""
397+
GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
398+
Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
399+
400+
(l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
401+
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
402+
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)
403+
282404
# Apply a GINConv layer snapshot-wise to a temporal graph: the i-th snapshot is
# paired with the i-th feature array in `x`; returns a vector of outputs.
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

test/layers/temporalconv.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ end
3434
@test model(g1) isa GNNGraph
3535
end
3636

37+
@testset "GConvLSTMCell" begin
38+
gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes)
39+
(h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
40+
@test size(h) == (out_channel, N)
41+
@test size(c) == (out_channel, N)
42+
end
43+
44+
@testset "GConvLSTM" begin
45+
gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes)
46+
@test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
47+
model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
48+
end
49+
3750
@testset "GConvGRUCell" begin
3851
gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes)
3952
h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
@@ -55,7 +68,6 @@ end
5568
@test length(Flux.gradient(x ->sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S
5669
end
5770

58-
5971
@testset "ChebConv" begin
6072
chebconv = ChebConv(in_channel => out_channel, 5)
6173
@test length(chebconv(tg, tg.ndata.x)) == S

0 commit comments

Comments
 (0)