@@ -187,6 +187,98 @@ function Base.show(io::IO, a3tgcn::A3TGCN)
    print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))")
end
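+ # Recurrent cell for GConvGRU: six ChebConv layers implement the reset,
+ # update, and candidate gates of a GRU, so each gate mixes information
+ # along graph edges as well as across time steps.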
+ struct GConvGRUCell <: GNNLayer
+     conv_x_r::ChebConv
+     conv_h_r::ChebConv
+     conv_x_z::ChebConv
+     conv_h_z::ChebConv
+     conv_x_h::ChebConv
+     conv_h_h::ChebConv
+     k::Int
+     state0
+     in::Int
+     out::Int
+ end
+
+ Flux.@functor GConvGRUCell
+
+ function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int;
+                       bias::Bool = true,
+                       init = Flux.glorot_uniform,
+                       init_state = Flux.zeros32)
+     in, out = ch
+     # reset gate
+     conv_x_r = ChebConv(in => out, k; bias, init)
+     conv_h_r = ChebConv(out => out, k; bias, init)
+     # update gate
+     conv_x_z = ChebConv(in => out, k; bias, init)
+     conv_h_z = ChebConv(out => out, k; bias, init)
+     # new gate (candidate hidden state)
+     conv_x_h = ChebConv(in => out, k; bias, init)
+     conv_h_h = ChebConv(out => out, k; bias, init)
+     # the initial hidden state has shape (out, n), so the cell is tied to a
+     # graph with a fixed number of nodes n
+     state0 = init_state(out, n)
+     return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out)
+ end
+
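+ # One GRU step on graph g (⋆ = ChebConv graph convolution, ⊙ = elementwise product):
+ #     r  = σ(W_r ⋆ x + U_r ⋆ h)              reset gate
+ #     z  = σ(W_z ⋆ x + U_z ⋆ h)              update gate
+ #     h̃  = tanh(W_h ⋆ x + U_h ⋆ (r ⊙ h))     candidate state
+ #     h  = (1 − z) ⊙ h̃ + z ⊙ h               new hidden state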
+ function (ggru::GConvGRUCell)(h, g::GNNGraph, x)
+     r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h)
+     r = Flux.sigmoid_fast(r)
+     z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h)
+     z = Flux.sigmoid_fast(z)
+     h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* h)
+     h̃ = Flux.tanh_fast(h̃)
+     h = (1 .- z) .* h̃ .+ z .* h
+     # return (new_state, output): the pair expected by Flux.Recur
+     return h, h
+ end
+
+ function Base.show(io::IO, ggru::GConvGRUCell)
+     print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))")
+ end
+
+ """
+     GConvGRU(in => out, k, n; [bias, init, init_state])
+
+ Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
+
+ Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
+
+ # Arguments
+
+ - `in`: Number of input features.
+ - `out`: Number of output features.
+ - `k`: Chebyshev polynomial order.
+ - `n`: Number of nodes in the graph.
+ - `bias`: Add learnable bias. Default `true`.
+ - `init`: Weights' initializer. Default `glorot_uniform`.
+ - `init_state`: Initializer for the hidden state of the GRU layer. Default `zeros32`.
+
+ # Examples
+
+ ```jldoctest
+ julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
+
+ julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes);
+
+ julia> y = ggru(g1, x1);
+
+ julia> size(y)
+ (5, 5)
+
+ julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
+
+ julia> z = ggru(g2, x2);
+
+ julia> size(z)
+ (5, 5, 30)
+ ```
+ """
+ GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
+ Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
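+ # GConvGRU is a Flux.Recur wrapping the cell: the hidden state is carried
+ # across calls and can be reset to state0 with Flux.reset!. On a 3-dim input
+ # (features × nodes × time steps) the wrapper applies the cell along the last
+ # dimension, as the (5, 5, 30) output in the doctest above shows.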
+
+ (l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+ _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
+ _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
+
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end