Skip to content

Commit 763a03d

Browse files
committed
Add GConvGRU docs
1 parent b4c2314 commit 763a03d

File tree

1 file changed

+83
-1
lines changed

1 file changed

+83
-1
lines changed

GNNLux/src/layers/temporalconv.jl

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,49 @@ y, st = tgcn(g, x, ps, st) # result size (6, 5)
103103
"""
104104
TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))
105105

106+
"""
107+
A3TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
108+
109+
Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph
110+
Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf).
111+
112+
Performs a TGCN layer, followed by a soft attention layer.
113+
114+
# Arguments
115+
116+
- `in`: Number of input features.
117+
- `out`: Number of output features.
118+
- `use_bias`: Add learnable bias. Default `true`.
119+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
120+
- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
121+
- `init_bias`: Bias initializer. Default `zeros32`.
122+
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
123+
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
124+
If `add_self_loops=true` the new weights will be set to 1.
125+
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
126+
Default `false`.
127+
# Examples
128+
129+
```julia
130+
using GNNLux, Lux, Random
131+
132+
# initialize random number generator
133+
rng = Random.default_rng()
134+
135+
# create data
136+
g = rand_graph(rng, 5, 10)
137+
x = rand(rng, Float32, 2, 5)
138+
139+
# create A3TGCN layer
140+
l = A3TGCN(2 => 6)
141+
142+
# setup layer
143+
ps, st = LuxCore.setup(rng, l)
144+
145+
# forward pass
146+
y, st = l(g, x, ps, st) # result size (6, 5)
147+
```
148+
"""
106149
@concrete struct A3TGCN <: GNNContainerLayer{(:tgcn, :dense1, :dense2)}
107150
in_dims::Int
108151
out_dims::Int
@@ -139,7 +182,7 @@ function Base.show(io::IO, l::A3TGCN)
139182
print(io, "A3TGCN($(l.in_dims) => $(l.out_dims))")
140183
end
141184

142-
@concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
185+
@concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
143186
in_dims::Int
144187
out_dims::Int
145188
k::Int
@@ -192,6 +235,45 @@ end
192235

193236
LuxCore.outputsize(l::GConvGRUCell) = (l.out_dims,)
194237

238+
"""
239+
GConvGRU(in => out, k; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
240+
241+
Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
242+
243+
Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
244+
245+
# Arguments
246+
247+
- `in`: Number of input features.
248+
- `out`: Number of output features.
249+
- `k`: Chebyshev polynomial order.
250+
- `use_bias`: Add learnable bias. Default `true`.
251+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
252+
- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
253+
- `init_bias`: Bias initializer. Default `zeros32`.
254+
255+
# Examples
256+
257+
```julia
258+
using GNNLux, Lux, Random
259+
260+
# initialize random number generator
261+
rng = Random.default_rng()
262+
263+
# create data
264+
g = rand_graph(rng, 5, 10)
265+
x = rand(rng, Float32, 2, 5)
266+
267+
# create A3TGCN layer
268+
l = GConvGRU(2 => 5, 2)
269+
270+
# setup layer
271+
ps, st = LuxCore.setup(rng, l)
272+
273+
# forward pass
274+
y, st = l(g, x, ps, st) # result size (5, 5)
275+
```
276+
"""
195277
GConvGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(GConvGRUCell(ch, k; kwargs...))
196278

197279
@concrete struct GConvLSTMCell <: GNNContainerLayer{(:conv_x_i, :conv_h_i, :dense_i, :conv_x_f, :conv_h_f, :dense_f, :conv_x_c, :conv_h_c, :dense_c, :conv_x_o, :conv_h_o, :dense_o)}

0 commit comments

Comments
 (0)