Skip to content

Commit 44704b0

Browse files
committed
Add GConvLASTM docs
1 parent 763a03d commit 44704b0

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

GNNLux/src/layers/temporalconv.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ rng = Random.default_rng()
264264
g = rand_graph(rng, 5, 10)
265265
x = rand(rng, Float32, 2, 5)
266266
267-
# create A3TGCN layer
267+
# create layer
268268
l = GConvGRU(2 => 5, 2)
269269
270270
# setup layer
@@ -357,6 +357,45 @@ end
357357

358358
LuxCore.outputsize(l::GConvLSTMCell) = (l.out_dims,)
359359

360+
"""
361+
GConvLSTM(in => out, k; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
362+
363+
Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
364+
365+
Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
366+
367+
# Arguments
368+
369+
- `in`: Number of input features.
370+
- `out`: Number of output features.
371+
- `k`: Chebyshev polynomial order.
372+
- `use_bias`: Add learnable bias. Default `true`.
373+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
374+
- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
375+
- `init_bias`: Bias initializer. Default `zeros32`.
376+
377+
# Examples
378+
379+
```julia
380+
using GNNLux, Lux, Random
381+
382+
# initialize random number generator
383+
rng = Random.default_rng()
384+
385+
# create data
386+
g = rand_graph(rng, 5, 10)
387+
x = rand(rng, Float32, 2, 5)
388+
389+
# create GConvLSTM layer
390+
l = GConvLSTM(2 => 5, 2)
391+
392+
# setup layer
393+
ps, st = LuxCore.setup(rng, l)
394+
395+
# forward pass
396+
y, st = l(g, x, ps, st) # result size (5, 5)
397+
```
398+
"""
360399
GConvLSTM(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(GConvLSTMCell(ch, k; kwargs...))
361400

362401
@concrete struct DCGRUCell <: GNNContainerLayer{(:dconv_u, :dconv_r, :dconv_c)}

0 commit comments

Comments
 (0)