Skip to content

Commit cfd9ec3

Browse files
GNNRecurrence
1 parent 02919ac commit cfd9ec3

File tree

3 files changed

+220
-195
lines changed

3 files changed

+220
-195
lines changed

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ include("layers/heteroconv.jl")
5050
export HeteroGraphConv
5151

5252
include("layers/temporalconv.jl")
53-
export GConvGRU, GConvGRUCell,
53+
export GNNRecurrence,
54+
GConvGRU, GConvGRUCell,
5455
GConvLSTM, GConvLSTMCell,
56+
DCGRU, DCGRUCell,
5557
TGCN,
5658
A3TGCN,
57-
DCGRU,
5859
EvolveGCNO
5960

6061
include("layers/pool.jl")

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 217 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,73 @@ function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
77
return stack(y, dims = 2)
88
end
99

10+
11+
"""
12+
GNNRecurrence(cell)
13+
14+
Construct a recurrent layer that applies the `cell`
15+
to process an entire temporal sequence of node features at once.
16+
17+
# Forward
18+
19+
layer(g::GNNGraph, x, [state])
20+
21+
- `g`: The input graph.
22+
- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
23+
- `state`: The initial state of the cell.
24+
If not provided, it is generated by calling `Flux.initialstates(cell)`.
25+
26+
Applies the recurrent cell to each timestep of the input sequence and returns the output as
27+
an array of size `out_features x timesteps x num_nodes`.
28+
29+
# Examples
30+
31+
```jldoctest
32+
julia> num_nodes, num_edges = 5, 10;
33+
34+
julia> d_in, d_out = 2, 3;
35+
36+
julia> timesteps = 5;
37+
38+
julia> g = rand_graph(num_nodes, num_edges);
39+
40+
julia> x = rand(Float32, d_in, timesteps, num_nodes);
41+
42+
julia> cell = GConvLSTMCell(d_in => d_out, 2)
43+
GConvLSTMCell(2 => 3, 2) # 168 parameters
44+
45+
julia> layer = GNNRecurrence(cell)
46+
GNNRecurrence(
47+
GConvLSTMCell(2 => 3, 2), # 168 parameters
48+
) # Total: 24 arrays, 168 parameters, 2.023 KiB.
49+
50+
julia> y = layer(g, x);
51+
52+
julia> size(y) # (d_out, timesteps, num_nodes)
53+
(3, 5, 5)
54+
```
55+
"""
56+
struct GNNRecurrence{G} <: GNNLayer
57+
cell::G
58+
end
59+
60+
Flux.@layer GNNRecurrence
61+
62+
Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell)
63+
64+
function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}) where {T}
65+
return rnn(g, x, initialstates(rnn))
66+
end
67+
68+
function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
69+
return scan(rnn.cell, g, x, state)
70+
end
71+
72+
function Base.show(io::IO, rnn::GNNRecurrence)
73+
print(io, "GNNRecurrence($(rnn.cell))")
74+
end
75+
76+
1077
"""
1178
GConvGRUCell(in => out, k; [bias, init])
1279
@@ -126,24 +193,13 @@ function Base.show(io::IO, cell::GConvGRUCell)
126193
end
127194

128195
"""
129-
GConvGRU(in => out, k; kws...)
130-
131-
The recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell,
132-
used to process an entire temporal sequence of node features at once.
196+
GConvGRU(args...; kws...)
133197
134-
The arguments are the same as for [`GConvGRUCell`](@ref).
198+
Construct a recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell.
199+
It can be used to process an entire temporal sequence of node features at once.
135200
136-
# Forward
137-
138-
layer(g::GNNGraph, x, [h])
139-
140-
- `g`: The input graph.
141-
- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
142-
- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
143-
If not provided, it is assumed to be a matrix of zeros.
144-
145-
Applies the recurrent cell to each timestep of the input sequence and returns the output as
146-
an array of size `out x timesteps x num_nodes`.
201+
The arguments are passed to the [`GConvGRUCell`](@ref) constructor.
202+
See [`GNNRecurrence`](@ref) for more details.
147203
148204
# Examples
149205
@@ -158,33 +214,18 @@ julia> g = rand_graph(num_nodes, num_edges);
158214
159215
julia> x = rand(Float32, d_in, timesteps, num_nodes);
160216
161-
julia> layer = GConvGRU(d_in => d_out, 2);
217+
julia> layer = GConvGRU(d_in => d_out, 2)
218+
GNNRecurrence(
219+
GConvGRUCell(2 => 3, 2), # 108 parameters
220+
) # Total: 12 arrays, 108 parameters, 1.148 KiB.
162221
163222
julia> y = layer(g, x);
164223
165224
julia> size(y) # (d_out, timesteps, num_nodes)
166225
(3, 5, 5)
167226
```
168227
"""
169-
struct GConvGRU{G<:GConvGRUCell} <: GNNLayer
170-
cell::G
171-
end
172-
173-
Flux.@layer GConvGRU
174-
175-
function GConvGRU(ch::Pair{Int,Int}, k::Int; kws...)
176-
return GConvGRU(GConvGRUCell(ch, k; kws...))
177-
end
178-
179-
Flux.initialstates(rnn::GConvGRU) = Flux.initialstates(rnn.cell)
180-
181-
function (rnn::GConvGRU)(g::GNNGraph, x::AbstractArray)
182-
return scan(rnn.cell, g, x, initialstates(rnn))
183-
end
184-
185-
function Base.show(io::IO, rnn::GConvGRU)
186-
print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
187-
end
228+
GConvGRU(args...; kws...) = GNNRecurrence(GConvGRUCell(args...; kws...))
188229

189230

190231
"""
@@ -268,7 +309,7 @@ julia> size(y) # (d_out, num_nodes)
268309
out::Int
269310
end
270311

271-
Flux.@layer GConvLSTMCell
312+
Flux.@layer :noexpand GConvLSTMCell
272313

273314
function GConvLSTMCell(ch::Pair{Int, Int}, k::Int;
274315
bias::Bool = true,
@@ -305,6 +346,8 @@ function Flux.initialstates(cell::GConvLSTMCell)
305346
(zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out))
306347
end
307348

349+
(cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
350+
308351
function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
309352
if h isa AbstractVector
310353
h = repeat(h, 1, g.num_nodes)
@@ -334,29 +377,74 @@ end
334377

335378

336379
"""
337-
GConvLSTM(in => out, k; kws...)
380+
GConvLSTM(args...; kws...)
381+
382+
Construct a recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell.
383+
It can be used to process an entire temporal sequence of node features at once.
384+
385+
The arguments are passed to the [`GConvLSTMCell`](@ref) constructor.
386+
See [`GNNRecurrence`](@ref) for more details.
387+
388+
# Examples
389+
390+
```jldoctest
391+
julia> num_nodes, num_edges = 5, 10;
392+
393+
julia> d_in, d_out = 2, 3;
394+
395+
julia> timesteps = 5;
396+
397+
julia> g = rand_graph(num_nodes, num_edges);
398+
399+
julia> x = rand(Float32, d_in, timesteps, num_nodes);
400+
401+
julia> layer = GConvLSTM(d_in => d_out, 2)
402+
GNNRecurrence(
403+
GConvLSTMCell(2 => 3, 2), # 168 parameters
404+
) # Total: 24 arrays, 168 parameters, 2.023 KiB.
405+
406+
julia> y = layer(g, x);
407+
408+
julia> size(y) # (d_out, timesteps, num_nodes)
409+
(3, 5, 5)
410+
```
411+
"""
412+
GConvLSTM(args...; kws...) = GNNRecurrence(GConvLSTMCell(args...; kws...))
413+
414+
"""
415+
DCGRUCell(in => out, k; [bias, init])
416+
417+
Diffusion Convolutional Recurrent Neural Network (DCGRU) cell from the paper
418+
[Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/abs/1707.01926).
419+
420+
Applies a [`DConv`](@ref) layer to model spatial dependencies,
421+
in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
338422
339-
The recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell,
340-
used to process an entire temporal sequence of node features at once.
423+
# Arguments
341424
342-
The arguments are the same as for [`GConvLSTMCell`](@ref).
425+
- `in`: Number of input node features.
426+
- `out`: Number of output node features.
427+
- `k`: Diffusion step for the `DConv`.
428+
- `bias`: Add learnable bias. Default `true`.
429+
- `init`: Weights' initializer. Default `glorot_uniform`.
343430
344431
# Forward
345432
346-
layer(g::GNNGraph, x, [state])
433+
cell(g::GNNGraph, x, [h])
347434
348435
- `g`: The input graph.
349-
- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
350-
- `state`: The initial hidden state of the LSTM cell.
351-
If given, it is a tuple `(h, c)` where both elements are matrices of size `out x num_nodes`.
352-
If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
436+
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
437+
- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
438+
If not provided, it is assumed to be a matrix of zeros.
353439
354-
Applies the recurrent cell to each timestep of the input sequence and returns the output as
355-
an array of size `out x timesteps x num_nodes`.
440+
Performs one recurrence step and returns a tuple `(h, h)`,
441+
where `h` is the updated hidden state of the GRU cell.
356442
357443
# Examples
358444
359445
```jldoctest
446+
julia> using GraphNeuralNetworks, Flux
447+
360448
julia> num_nodes, num_edges = 5, 10;
361449
362450
julia> d_in, d_out = 2, 3;
@@ -365,33 +453,98 @@ julia> timesteps = 5;
365453
366454
julia> g = rand_graph(num_nodes, num_edges);
367455
368-
julia> x = rand(Float32, d_in, timesteps, num_nodes);
456+
julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];
369457
370-
julia> layer = GConvLSTM(d_in => d_out, 2);
458+
julia> cell = DCGRUCell(d_in => d_out, 2);
371459
372-
julia> y = layer(g, x);
460+
julia> state = Flux.initialstates(cell);
373461
374-
julia> size(y) # (d_out, timesteps, num_nodes)
375-
(3, 5, 5)
462+
julia> y = state;
463+
464+
julia> for xt in x
465+
y, state = cell(g, xt, state)
466+
end
467+
468+
julia> size(y) # (d_out, num_nodes)
469+
(3, 5)
376470
```
377-
"""
378-
struct GConvLSTM{G<:GConvLSTMCell} <: GNNLayer
379-
cell::G
471+
"""
472+
struct DCGRUCell
473+
in::Int
474+
out::Int
475+
k::Int
476+
dconv_u::DConv
477+
dconv_r::DConv
478+
dconv_c::DConv
380479
end
381480

382-
Flux.@layer GConvLSTM
481+
Flux.@layer :noexpand DCGRUCell
383482

384-
function GConvLSTM(ch::Pair{Int,Int}, k::Int; kws...)
385-
return GConvLSTM(GConvLSTMCell(ch, k; kws...))
483+
function DCGRUCell(ch::Pair{Int,Int}, k::Int; bias = true, init = glorot_uniform)
484+
in, out = ch
485+
dconv_u = DConv((in + out) => out, k; bias, init)
486+
dconv_r = DConv((in + out) => out, k; bias, init)
487+
dconv_c = DConv((in + out) => out, k; bias, init)
488+
return DCGRUCell(in, out, k, dconv_u, dconv_r, dconv_c)
386489
end
387490

388-
Flux.initialstates(rnn::GConvLSTM) = Flux.initialstates(rnn.cell)
491+
Flux.initialstates(cell::DCGRUCell) = zeros_like(cell.dconv_u.weights, cell.out)
492+
493+
(cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
389494

390-
function (rnn::GConvLSTM)(g::GNNGraph, x::AbstractArray)
391-
return scan(rnn.cell, g, x, initialstates(rnn))
495+
function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
496+
return cell(g, x, repeat(h, 1, g.num_nodes))
392497
end
393498

394-
function Base.show(io::IO, rnn::GConvLSTM)
395-
print(io, "GConvLSTM($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
499+
function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
500+
h̃ = vcat(x, h)
501+
z = cell.dconv_u(g, h̃)
502+
z = NNlib.sigmoid_fast.(z)
503+
r = cell.dconv_r(g, h̃)
504+
r = NNlib.sigmoid_fast.(r)
505+
ĥ = vcat(x, h .* r)
506+
c = cell.dconv_c(g, ĥ)
507+
c = NNlib.tanh_fast.(c)
508+
h = z .* h + (1 .- z) .* c
509+
return h, h
510+
end
511+
512+
function Base.show(io::IO, cell::DCGRUCell)
513+
print(io, "DCGRUCell($(cell.in) => $(cell.out), $(cell.k))")
396514
end
397515

516+
"""
517+
DCGRU(args...; kws...)
518+
519+
Construct a recurrent layer corresponding to the [`DCGRUCell`](@ref) cell.
520+
It can be used to process an entire temporal sequence of node features at once.
521+
522+
The arguments are passed to the [`DCGRUCell`](@ref) constructor.
523+
See [`GNNRecurrence`](@ref) for more details.
524+
525+
# Examples
526+
```jldoctest
527+
julia> num_nodes, num_edges = 5, 10;
528+
529+
julia> d_in, d_out = 2, 3;
530+
531+
julia> timesteps = 5;
532+
533+
julia> g = rand_graph(num_nodes, num_edges);
534+
535+
julia> x = rand(Float32, d_in, timesteps, num_nodes);
536+
537+
julia> layer = DCGRU(d_in => d_out, 2)
538+
GNNRecurrence(
539+
DCGRUCell(2 => 3, 2), # 189 parameters
540+
) # Total: 6 arrays, 189 parameters, 1.184 KiB.
541+
542+
julia> y = layer(g, x);
543+
544+
julia> size(y) # (d_out, timesteps, num_nodes)
545+
(3, 5, 5)
546+
```
547+
"""
548+
DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...))
549+
550+

0 commit comments

Comments
 (0)