@@ -103,6 +103,49 @@ y, st = tgcn(g, x, ps, st) # result size (6, 5)
103103"""
104104TGCN (ch:: Pair{Int, Int} ; kwargs... ) = GNNLux. StatefulRecurrentCell (TGCNCell (ch; kwargs... ))
105105
106+ """
107+ A3TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
108+
109+ Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph
110+ Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf).
111+
112+ Performs a TGCN layer, followed by a soft attention layer.
113+
114+ # Arguments
115+
116+ - `in`: Number of input features.
117+ - `out`: Number of output features.
118+ - `use_bias`: Add learnable bias. Default `true`.
119+ - `init_weight`: Weights' initializer. Default `glorot_uniform`.
120+ - `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
121+ - `init_bias`: Bias initializer. Default `zeros32`.
122+ - `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
123+ - `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
124+ If `add_self_loops=true` the new weights will be set to 1.
125+ This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
126+ Default `false`.
127+ # Examples
128+
129+ ```julia
130+ using GNNLux, Lux, Random
131+
132+ # initialize random number generator
133+ rng = Random.default_rng()
134+
135+ # create data
136+ g = rand_graph(rng, 5, 10)
137+ x = rand(rng, Float32, 2, 5)
138+
139+ # create A3TGCN layer
140+ l = A3TGCN(2 => 6)
141+
142+ # setup layer
143+ ps, st = LuxCore.setup(rng, l)
144+
145+ # forward pass
146+ y, st = l(g, x, ps, st) # result size (6, 5)
147+ ```
148+ """
106149@concrete struct A3TGCN <: GNNContainerLayer{(:tgcn, :dense1, :dense2)}
107150 in_dims:: Int
108151 out_dims:: Int
@@ -139,7 +182,7 @@ function Base.show(io::IO, l::A3TGCN)
139182 print (io, " A3TGCN($(l. in_dims) => $(l. out_dims) )" )
140183end
141184
142- @concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
185+ @concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
143186 in_dims:: Int
144187 out_dims:: Int
145188 k:: Int
192235
193236LuxCore. outputsize (l:: GConvGRUCell ) = (l. out_dims,)
194237
238+ """
239+ GConvGRU(in => out, k; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
240+
241+ Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
242+
243+ Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
244+
245+ # Arguments
246+
247+ - `in`: Number of input features.
248+ - `out`: Number of output features.
249+ - `k`: Chebyshev polynomial order.
250+ - `use_bias`: Add learnable bias. Default `true`.
251+ - `init_weight`: Weights' initializer. Default `glorot_uniform`.
252+ - `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
253+ - `init_bias`: Bias initializer. Default `zeros32`.
254+
255+ # Examples
256+
257+ ```julia
258+ using GNNLux, Lux, Random
259+
260+ # initialize random number generator
261+ rng = Random.default_rng()
262+
263+ # create data
264+ g = rand_graph(rng, 5, 10)
265+ x = rand(rng, Float32, 2, 5)
266+
267+ # create A3TGCN layer
268+ l = GConvGRU(2 => 5, 2)
269+
270+ # setup layer
271+ ps, st = LuxCore.setup(rng, l)
272+
273+ # forward pass
274+ y, st = l(g, x, ps, st) # result size (5, 5)
275+ ```
276+ """
195277GConvGRU (ch:: Pair{Int, Int} , k:: Int ; kwargs... ) = GNNLux. StatefulRecurrentCell (GConvGRUCell (ch, k; kwargs... ))
196278
197279@concrete struct GConvLSTMCell <: GNNContainerLayer{(:conv_x_i, :conv_h_i, :dense_i, :conv_x_f, :conv_h_f, :dense_f, :conv_x_c, :conv_h_c, :dense_c, :conv_x_o, :conv_h_o, :dense_o)}
0 commit comments