Skip to content

Commit 8c42880

Browse files
committed
Add DCGRU docs
1 parent 44704b0 commit 8c42880

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

GNNLux/src/layers/temporalconv.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,46 @@ end
438438

439439
LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)
440440

441+
"""
442+
DCGRU(in => out, k; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
443+
444+
Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural
445+
Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
446+
447+
Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
448+
449+
# Arguments
450+
451+
- `in`: Number of input features.
452+
- `out`: Number of output features.
453+
- `k`: Diffusion step.
454+
- `use_bias`: Add learnable bias. Default `true`.
455+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
456+
- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
457+
- `init_bias`: Bias initializer. Default `zeros32`.
458+
459+
# Examples
460+
461+
```julia
462+
using GNNLux, Lux, Random
463+
464+
# initialize random number generator
465+
rng = Random.default_rng()
466+
467+
# create data
468+
g = rand_graph(rng, 5, 10)
469+
x = rand(rng, Float32, 2, 5)
470+
471+
# create layer
472+
l = DCGRU(2 => 5, 2)
473+
474+
# setup layer
475+
ps, st = LuxCore.setup(rng, l)
476+
477+
# forward pass
478+
y, st = l(g, x, ps, st) # result size (5, 5)
479+
```
480+
"""
441481
DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))
442482

443483
@concrete struct EvolveGCNO <: GNNLayer

0 commit comments

Comments
 (0)