Commit 02919ac

GConvLSTM
1 parent 43aedc2 commit 02919ac

3 files changed: +211 -187 lines changed

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion
@@ -51,9 +51,9 @@ export HeteroGraphConv
 
 include("layers/temporalconv.jl")
 export GConvGRU, GConvGRUCell,
+       GConvLSTM, GConvLSTMCell,
        TGCN,
        A3TGCN,
-       GConvLSTM,
        DCGRU,
        EvolveGCNO
 

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 210 additions & 1 deletion
@@ -166,7 +166,7 @@ julia> size(y) # (d_out, timesteps, num_nodes)
 (3, 5, 5)
 ```
 """
-struct GConvGRU{G <: GConvGRUCell} <: GNNLayer
+struct GConvGRU{G<:GConvGRUCell} <: GNNLayer
     cell::G
 end
 
@@ -186,3 +186,212 @@ function Base.show(io::IO, rnn::GConvGRU)
     print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
 end
 
+
+"""
+    GConvLSTMCell(in => out, k; [bias, init])
+
+Graph Convolutional LSTM recurrent cell from the paper
+[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659).
+
+Uses [`ChebConv`](@ref) to model spatial dependencies,
+followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
+
+# Arguments
+
+- `in => out`: A pair where `in` is the number of input node features and `out`
+  is the number of output node features.
+- `k`: Chebyshev polynomial order.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+
+# Forward
+
+    cell(g::GNNGraph, x, [state])
+
+- `g`: The input graph.
+- `x`: The node features. It should be a matrix of size `in x num_nodes`.
+- `state`: The initial hidden state of the LSTM cell.
+  If given, it is a tuple `(h, c)` where both `h` and `c` are arrays of size `out x num_nodes`.
+  If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
+
+Performs one recurrence step and returns a tuple `(output, state)`,
+where `output` is the updated hidden state `h` of the LSTM cell and `state` is the updated tuple `(h, c)`.
+
+# Examples
+
+```jldoctest
+julia> using GraphNeuralNetworks, Flux
+
+julia> num_nodes, num_edges = 5, 10;
+
+julia> d_in, d_out = 2, 3;
+
+julia> timesteps = 5;
+
+julia> g = rand_graph(num_nodes, num_edges);
+
+julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];
+
+julia> cell = GConvLSTMCell(d_in => d_out, 2);
+
+julia> state = Flux.initialstates(cell);
+
+julia> y = state[1];
+
+julia> for xt in x
+           y, state = cell(g, xt, state)
+       end
+
+julia> size(y) # (d_out, num_nodes)
+(3, 5)
+```
+"""
+@concrete struct GConvLSTMCell <: GNNLayer
+    conv_x_i
+    conv_h_i
+    w_i
+    b_i
+    conv_x_f
+    conv_h_f
+    w_f
+    b_f
+    conv_x_c
+    conv_h_c
+    w_c
+    b_c
+    conv_x_o
+    conv_h_o
+    w_o
+    b_o
+    k::Int
+    in::Int
+    out::Int
+end
+
+Flux.@layer GConvLSTMCell
+
+function GConvLSTMCell(ch::Pair{Int, Int}, k::Int;
+                       bias::Bool = true,
+                       init = Flux.glorot_uniform)
+    in, out = ch
+    # input gate
+    conv_x_i = ChebConv(in => out, k; bias, init)
+    conv_h_i = ChebConv(out => out, k; bias, init)
+    w_i = init(out, 1)
+    b_i = bias ? Flux.create_bias(w_i, true, out) : false
+    # forget gate
+    conv_x_f = ChebConv(in => out, k; bias, init)
+    conv_h_f = ChebConv(out => out, k; bias, init)
+    w_f = init(out, 1)
+    b_f = bias ? Flux.create_bias(w_f, true, out) : false
+    # cell state
+    conv_x_c = ChebConv(in => out, k; bias, init)
+    conv_h_c = ChebConv(out => out, k; bias, init)
+    w_c = init(out, 1)
+    b_c = bias ? Flux.create_bias(w_c, true, out) : false
+    # output gate
+    conv_x_o = ChebConv(in => out, k; bias, init)
+    conv_h_o = ChebConv(out => out, k; bias, init)
+    w_o = init(out, 1)
+    b_o = bias ? Flux.create_bias(w_o, true, out) : false
+    return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i,
+                         conv_x_f, conv_h_f, w_f, b_f,
+                         conv_x_c, conv_h_c, w_c, b_c,
+                         conv_x_o, conv_h_o, w_o, b_o,
+                         k, in, out)
+end
+
+function Flux.initialstates(cell::GConvLSTMCell)
+    (zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out))
+end
+
+function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
+    if h isa AbstractVector
+        h = repeat(h, 1, g.num_nodes)
+    end
+    if c isa AbstractVector
+        c = repeat(c, 1, g.num_nodes)
+    end
+    @assert ndims(h) == 2 && ndims(c) == 2
+    # input gate
+    i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i
+    i = Flux.sigmoid_fast(i)
+    # forget gate
+    f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f
+    f = Flux.sigmoid_fast(f)
+    # cell state
+    c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c)
+    # output gate
+    o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o
+    o = Flux.sigmoid_fast(o)
+    h = o .* Flux.tanh_fast(c)
+    return h, (h, c)
+end
+
+function Base.show(io::IO, cell::GConvLSTMCell)
+    print(io, "GConvLSTMCell($(cell.in) => $(cell.out), $(cell.k))")
+end
+
+
+"""
+    GConvLSTM(in => out, k; kws...)
+
+The recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell,
+used to process an entire temporal sequence of node features at once.
+
+The arguments are the same as for [`GConvLSTMCell`](@ref).
+
+# Forward
+
+    layer(g::GNNGraph, x, [state])
+
+- `g`: The input graph.
+- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
+- `state`: The initial hidden state of the LSTM cell.
+  If given, it is a tuple `(h, c)` where both elements are matrices of size `out x num_nodes`.
+  If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
+
+Applies the recurrent cell to each timestep of the input sequence and returns the output as
+an array of size `out x timesteps x num_nodes`.
+
+# Examples
+
+```jldoctest
+julia> num_nodes, num_edges = 5, 10;
+
+julia> d_in, d_out = 2, 3;
+
+julia> timesteps = 5;
+
+julia> g = rand_graph(num_nodes, num_edges);
+
+julia> x = rand(Float32, d_in, timesteps, num_nodes);
+
+julia> layer = GConvLSTM(d_in => d_out, 2);
+
+julia> y = layer(g, x);
+
+julia> size(y) # (d_out, timesteps, num_nodes)
+(3, 5, 5)
+```
+"""
+struct GConvLSTM{G<:GConvLSTMCell} <: GNNLayer
+    cell::G
+end
+
+Flux.@layer GConvLSTM
+
+function GConvLSTM(ch::Pair{Int,Int}, k::Int; kws...)
+    return GConvLSTM(GConvLSTMCell(ch, k; kws...))
+end
+
+Flux.initialstates(rnn::GConvLSTM) = Flux.initialstates(rnn.cell)
+
+function (rnn::GConvLSTM)(g::GNNGraph, x::AbstractArray)
+    return scan(rnn.cell, g, x, initialstates(rnn))
+end
+
+function Base.show(io::IO, rnn::GConvLSTM)
+    print(io, "GConvLSTM($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
+end
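For reference, the recurrence added in `temporalconv.jl` above is a peephole-LSTM update in which every dense multiplication of a standard LSTM is replaced by an order-`k` Chebyshev graph convolution ([`ChebConv`](@ref)). Sketching the update as the forward pass computes it, writing $\star$ for a `ChebConv` application and $\odot$ for the broadcasted product:

```math
\begin{aligned}
i &= \sigma\bigl(W_{xi} \star x_t + W_{hi} \star h_{t-1} + w_i \odot c_{t-1} + b_i\bigr) \\
f &= \sigma\bigl(W_{xf} \star x_t + W_{hf} \star h_{t-1} + w_f \odot c_{t-1} + b_f\bigr) \\
c_t &= f \odot c_{t-1} + i \odot \tanh\bigl(W_{xc} \star x_t + W_{hc} \star h_{t-1} + w_c \odot c_{t-1} + b_c\bigr) \\
o &= \sigma\bigl(W_{xo} \star x_t + W_{ho} \star h_{t-1} + w_o \odot c_t + b_o\bigr) \\
h_t &= o \odot \tanh(c_t)
\end{aligned}
```

Two details worth noting from the code: the peephole weights `w_i`, `w_f`, `w_c`, `w_o` have size `(out, 1)`, so each peephole coefficient is shared across nodes and broadcast over them; and the output gate reads the already-updated cell state `c_t`, since `c` is reassigned before `o` is computed.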

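To make the `scan` call in the `GConvLSTM` forward pass concrete: applying the layer to an `in x timesteps x num_nodes` array behaves, up to `scan`'s internal details, like stepping the cell over time slices and stacking the hidden states. A minimal sketch follows; the `unroll` helper is hypothetical (not part of the package) and assumes `scan` slices `x` along dimension 2, as the documented output shape `(out, timesteps, num_nodes)` suggests.

```julia
using GraphNeuralNetworks, Flux

# Hypothetical helper: unroll the cell over the time dimension
# and stack the per-step hidden states along dimension 2.
function unroll(layer::GConvLSTM, g::GNNGraph, x::AbstractArray)
    state = Flux.initialstates(layer)
    ys = Matrix{Float32}[]               # one out x num_nodes slice per step
    for t in 1:size(x, 2)
        y, state = layer.cell(g, x[:, t, :], state)  # one recurrence step
        push!(ys, y)
    end
    return stack(ys; dims = 2)           # out x timesteps x num_nodes
end

g = rand_graph(5, 10)                    # 5 nodes, 10 edges
x = rand(Float32, 2, 7, 5)               # d_in = 2, timesteps = 7, num_nodes = 5
layer = GConvLSTM(2 => 4, 2)
size(unroll(layer, g, x))                # (4, 7, 5), same shape as layer(g, x)
```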