
Commit 8036671

Add docs
1 parent e0c0cc2 commit 8036671

1 file changed (+46, −8 lines)

src/layers/temporalconv.jl

Lines changed: 46 additions & 8 deletions
@@ -405,21 +405,21 @@ struct DCGRUCell
     in::Int
     out::Int
     state0
-    K::Int
+    k::Int
     dconv_u::DConv
     dconv_r::DConv
     dconv_c::DConv
 end
 
 Flux.@functor DCGRUCell
 
-function DCGRUCell(ch::Pair{Int,Int}, K::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
+function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
     in, out = ch
-    dconv_u = DConv((in + out) => out, K; bias=bias, init=init)
-    dconv_r = DConv((in + out) => out, K; bias=bias, init=init)
-    dconv_c = DConv((in + out) => out, K; bias=bias, init=init)
+    dconv_u = DConv((in + out) => out, k; bias=bias, init=init)
+    dconv_r = DConv((in + out) => out, k; bias=bias, init=init)
+    dconv_c = DConv((in + out) => out, k; bias=bias, init=init)
     state0 = init_state(out, n)
-    return DCGRUCell(in, out, state0, K, dconv_u, dconv_r, dconv_c)
+    return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c)
 end
 
 function (dcgru::DCGRUCell)(h, g::GNNGraph, x)
@@ -436,10 +436,48 @@ function (dcgru::DCGRUCell)(h, g::GNNGraph, x)
 end
 
 function Base.show(io::IO, dcgru::DCGRUCell)
-    print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.K))")
+    print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))")
 end
 
-DCGRU(ch, K, n; kwargs...) = Flux.Recur(DCGRUCell(ch, K, n; kwargs...))
+"""
+    DCGRU(in => out, k, n; [bias, init, init_state])
+
+Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural
+Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
+
+Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `k`: Diffusion step.
+- `n`: Number of nodes in the graph.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `init_state`: Initial hidden state of the GRU layer. Default `zeros32`.
+
+# Examples
+
+```jldoctest
+julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
+
+julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes);
+
+julia> y = dcgru(g1, x1);
+
+julia> size(y)
+(5, 5)
+
+julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
+
+julia> z = dcgru(g2, x2);
+
+julia> size(z)
+(5, 5, 30)
+```
+"""
+DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
 Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
 
 (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
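The hunk starting at `@@ -436,10` shows only the tail of the cell's forward pass, so how the three diffusion convolutions feed the GRU gates is not visible in this diff. Below is a minimal sketch of how a DCGRU-style cell is typically wired, written against the field names from the struct above; the gate order, activations, and concatenation layout are assumptions for illustration, not the file's actual implementation, and the helper name `dcgru_step` is hypothetical.

```julia
# Hypothetical sketch of a DCGRU-style forward step; not the file's actual body.
# `h` is the hidden state (out × n), `x` the node features (in × n).
function dcgru_step(cell::DCGRUCell, h, g::GNNGraph, x)
    xh = vcat(x, h)                               # stack input and hidden state
    u  = Flux.sigmoid.(cell.dconv_u(g, xh))       # update gate via diffusion conv
    r  = Flux.sigmoid.(cell.dconv_r(g, xh))       # reset gate via diffusion conv
    c  = tanh.(cell.dconv_c(g, vcat(x, r .* h)))  # candidate state from reset-gated h
    hn = u .* h .+ (1 .- u) .* c                  # GRU-style interpolation
    return hn, hn                                 # (new state, output), as Flux.Recur expects
end
```

With this shape, wrapping the cell in `Flux.Recur` (as the new `DCGRU` constructor does) carries the hidden state across calls, which is consistent with the docstring's examples where a `2 × 5 × 30` input yields a `5 × 5 × 30` output.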
