Skip to content

Commit cafc1bc

Browse files
Add DCGRU temporal layer (#448)
* Add `DCGRU` code * Add `DCGRU` tests * Add export * Add docs * Update src/layers/temporalconv.jl --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent f6b95fc commit cafc1bc

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ export
7575
A3TGCN,
7676
GConvLSTM,
7777
GConvGRU,
78+
DCGRU,
7879

7980
# layers/pool
8081
GlobalPool,

src/layers/temporalconv.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,89 @@ Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
401401
# `_applylayer` methods for a recurrent GConvLSTM layer: forward the call to the
# layer itself, with or without an explicit feature argument.
function _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x)
    return l(g, x)
end

function _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph)
    return l(g)
end
403403

404+
# Recurrent cell behind the `DCGRU` layer: holds the feature sizes, the initial
# hidden state and one `DConv` per gate of the cell.
struct DCGRUCell
    # Input / output feature sizes, taken from the `in => out` pair of the constructor.
    in::Int
    out::Int
    # Initial hidden state; passed to `Flux.Recur` as the starting state.
    # Intentionally left untyped, so any array type can be stored.
    state0
    # Second positional argument forwarded to every `DConv` below.
    k::Int
    # Convolutions producing `z`, `r` and the candidate `c` in the forward pass
    # (presumably the GRU update gate, reset gate and candidate state).
    dconv_u::DConv
    dconv_r::DConv
    dconv_c::DConv
end

Flux.@functor DCGRUCell
415+
416+
# Build a `DCGRUCell` from an `in => out` feature pair, the `DConv` parameter `k`
# and the number of nodes `n` (used only to size the initial hidden state).
function DCGRUCell(
    ch::Pair{Int,Int},
    k::Int,
    n::Int;
    bias = true,
    init = glorot_uniform,
    init_state = Flux.zeros32,
)
    nin, nout = ch
    # Every gate consumes the input features stacked with the hidden state, hence
    # `nin + nout` input channels, and emits `nout` channels.
    make_dconv() = DConv((nin + nout) => nout, k; bias = bias, init = init)
    # Keep the u, r, c construction order so that random initialisation draws
    # happen in the same sequence.
    dconv_u = make_dconv()
    dconv_r = make_dconv()
    dconv_c = make_dconv()
    state0 = init_state(nout, n)
    return DCGRUCell(nin, nout, state0, k, dconv_u, dconv_r, dconv_c)
end
424+
425+
# One recurrent step of the DCGRU cell.
#
# Arguments:
# - `h`: current hidden state (the cell is started from `state0`, built as
#   `init_state(out, n)`).
# - `g`: graph handed to each `DConv`.
# - `x`: node features for this step; assumed to be `in × n` so that they can be
#   stacked with `h` along the first dimension — TODO confirm against callers.
#
# Returns the tuple `(h, h)`: the new hidden state twice, once as the next state
# and once as the step output.
function (dcgru::DCGRUCell)(h, g::GNNGraph, x)
    # Input features stacked on top of the hidden state; fed to both gates.
    xh = vcat(x, h)
    # Gate `z` (blends old state and candidate below).
    z = NNlib.sigmoid_fast.(dcgru.dconv_u(g, xh))
    # Gate `r` (scales the hidden state seen by the candidate).
    r = NNlib.sigmoid_fast.(dcgru.dconv_r(g, xh))
    # Candidate state computed from the input and the gated hidden state.
    xrh = vcat(x, h .* r)
    c = tanh.(dcgru.dconv_c(g, xrh))
    # Element-wise blend of the previous state and the candidate, written as a
    # single fused broadcast so no intermediate arrays are materialised.
    h = z .* h .+ (1 .- z) .* c
    return h, h
end
437+
438+
# Compact one-line representation: `DCGRUCell(in => out, k)`.
function Base.show(io::IO, dcgru::DCGRUCell)
    description = "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))"
    print(io, description)
    return nothing
end
441+
442+
"""
443+
DCGRU(in => out, k, n; [bias, init, init_state])
444+
445+
Diffusion Convolutional Gated Recurrent Unit (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural
446+
Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
447+
448+
Uses diffusion convolutions to model spatial dependencies inside the gates of a Gated Recurrent Unit (GRU) cell, which models temporal dependencies.
449+
450+
# Arguments
451+
452+
- `in`: Number of input features.
453+
- `out`: Number of output features.
454+
- `k`: Diffusion step.
455+
- `n`: Number of nodes in the graph.
456+
- `bias`: Add learnable bias. Default `true`.
457+
- `init`: Weights' initializer. Default `glorot_uniform`.
458+
- `init_state`: Initializer for the hidden state of the GRU layer. Default `zeros32`.
459+
460+
# Examples
461+
462+
```jldoctest
463+
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
464+
465+
julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes);
466+
467+
julia> y = dcgru(g1, x1);
468+
469+
julia> size(y)
470+
(5, 5)
471+
472+
julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
473+
474+
julia> z = dcgru(g2, x2);
475+
476+
julia> size(z)
477+
(5, 5, 30)
478+
```
479+
"""
480+
function DCGRU(ch, k, n; kwargs...)
    # The layer is the cell wrapped in a `Flux.Recur`, which carries the state.
    cell = DCGRUCell(ch, k, n; kwargs...)
    return Flux.Recur(cell)
end

# Start the recurrence from the initial state stored in the cell.
function Flux.Recur(dcgru::DCGRUCell)
    return Flux.Recur(dcgru, dcgru.state0)
end

# Graph-only call: run the layer on the graph's node features and return a graph
# built from `g` with the result as its node data.
function (l::Flux.Recur{DCGRUCell})(g::GNNGraph)
    updated = l(g, node_features(g))
    return GNNGraph(g, ndata = updated)
end

# `_applylayer` methods for the recurrent DCGRU layer: forward to the layer call.
function _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x)
    return l(g, x)
end

function _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph)
    return l(g)
end
486+
404487
# Apply the layer to a temporal graph: the layer is broadcast over the snapshots
# paired element-wise with the entries of `x`.
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    snapshots = tg.snapshots
    return broadcast(l, snapshots, x)
end

test/layers/temporalconv.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ end
6161
@test model(g1) isa GNNGraph
6262
end
6363

64+
@testset "DCGRU" begin
    layer = DCGRU(in_channel => out_channel, 2, g1.num_nodes)
    # The gradient with respect to the node features must have the shape of the input.
    loss(x) = sum(layer(g1, x))
    @test size(first(Flux.gradient(loss, g1.ndata.x))) == (in_channel, N)

    # The recurrent layer must compose with other layers inside a `GNNChain`,
    # both when features are passed explicitly and when only the graph is given.
    chain = GNNChain(
        DCGRU(in_channel => out_channel, 2, g1.num_nodes),
        Dense(out_channel, 1),
    )
    @test size(chain(g1, g1.ndata.x)) == (1, N)
    @test chain(g1) isa GNNGraph
end
71+
6472
@testset "GINConv" begin
6573
ginconv = GINConv(Dense(in_channel => out_channel),0.3)
6674
@test length(ginconv(tg, tg.ndata.x)) == S

0 commit comments

Comments
 (0)