@@ -279,6 +279,128 @@ Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
279
279
_applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph , x) = l (g, x)
280
280
_applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph ) = l (g)
281
281
282
+ struct GConvLSTMCell <: GNNLayer
283
+ conv_x_i:: ChebConv
284
+ conv_h_i:: ChebConv
285
+ w_i
286
+ b_i
287
+ conv_x_f:: ChebConv
288
+ conv_h_f:: ChebConv
289
+ w_f
290
+ b_f
291
+ conv_x_c:: ChebConv
292
+ conv_h_c:: ChebConv
293
+ w_c
294
+ b_c
295
+ conv_x_o:: ChebConv
296
+ conv_h_o:: ChebConv
297
+ w_o
298
+ b_o
299
+ k:: Int
300
+ state0
301
+ in:: Int
302
+ out:: Int
303
+ end
304
+
305
+ Flux. @functor GConvLSTMCell
306
+
307
+ function GConvLSTMCell (ch:: Pair{Int, Int} , k:: Int , n:: Int ;
308
+ bias:: Bool = true ,
309
+ init = Flux. glorot_uniform,
310
+ init_state = Flux. zeros32)
311
+ in, out = ch
312
+ # input gate
313
+ conv_x_i = ChebConv (in => out, k; bias, init)
314
+ conv_h_i = ChebConv (out => out, k; bias, init)
315
+ w_i = init (out, 1 )
316
+ b_i = bias ? Flux. create_bias (w_i, true , out) : false
317
+ # forget gate
318
+ conv_x_f = ChebConv (in => out, k; bias, init)
319
+ conv_h_f = ChebConv (out => out, k; bias, init)
320
+ w_f = init (out, 1 )
321
+ b_f = bias ? Flux. create_bias (w_f, true , out) : false
322
+ # cell state
323
+ conv_x_c = ChebConv (in => out, k; bias, init)
324
+ conv_h_c = ChebConv (out => out, k; bias, init)
325
+ w_c = init (out, 1 )
326
+ b_c = bias ? Flux. create_bias (w_c, true , out) : false
327
+ # output gate
328
+ conv_x_o = ChebConv (in => out, k; bias, init)
329
+ conv_h_o = ChebConv (out => out, k; bias, init)
330
+ w_o = init (out, 1 )
331
+ b_o = bias ? Flux. create_bias (w_o, true , out) : false
332
+ state0 = (init_state (out, n), init_state (out, n))
333
+ return GConvLSTMCell (conv_x_i, conv_h_i, w_i, b_i,
334
+ conv_x_f, conv_h_f, w_f, b_f,
335
+ conv_x_c, conv_h_c, w_c, b_c,
336
+ conv_x_o, conv_h_o, w_o, b_o,
337
+ k, state0, in, out)
338
+ end
339
+
340
+ function (gclstm:: GConvLSTMCell )((h, c), g:: GNNGraph , x)
341
+ # input gate
342
+ i = gclstm. conv_x_i (g, x) .+ gclstm. conv_h_i (g, h) .+ gclstm. w_i .* c .+ gclstm. b_i
343
+ i = Flux. sigmoid_fast (i)
344
+ # forget gate
345
+ f = gclstm. conv_x_f (g, x) .+ gclstm. conv_h_f (g, h) .+ gclstm. w_f .* c .+ gclstm. b_f
346
+ f = Flux. sigmoid_fast (f)
347
+ # cell state
348
+ c = f .* c .+ i .* Flux. tanh_fast (gclstm. conv_x_c (g, x) .+ gclstm. conv_h_c (g, h) .+ gclstm. w_c .* c .+ gclstm. b_c)
349
+ # output gate
350
+ o = gclstm. conv_x_o (g, x) .+ gclstm. conv_h_o (g, h) .+ gclstm. w_o .* c .+ gclstm. b_o
351
+ o = Flux. sigmoid_fast (o)
352
+ h = o .* Flux. tanh_fast (c)
353
+ return (h,c), h
354
+ end
355
+
356
+ function Base. show (io:: IO , gclstm:: GConvLSTMCell )
357
+ print (io, " GConvLSTMCell($(gclstm. in) => $(gclstm. out) )" )
358
+ end
359
+
360
+ """
361
+ GConvLSTM(in => out, k, n; [bias, init, init_state])
362
+
363
+ Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
364
+
365
+ Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
366
+
367
+ # Arguments
368
+
369
+ - `in`: Number of input features.
370
+ - `out`: Number of output features.
371
+ - `k`: Chebyshev polynomial order.
372
+ - `n`: Number of nodes in the graph.
373
+ - `bias`: Add learnable bias. Default `true`.
374
+ - `init`: Weights' initializer. Default `glorot_uniform`.
375
+ - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.
376
+
377
+ # Examples
378
+
379
+ ```jldoctest
380
+ julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
381
+
382
+ julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes);
383
+
384
+ julia> y = gclstm(g1, x1);
385
+
386
+ julia> size(y)
387
+ (5, 5)
388
+
389
+ julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
390
+
391
+ julia> z = gclstm(g2, x2);
392
+
393
+ julia> size(z)
394
+ (5, 5, 30)
395
+ ```
396
+ """
397
+ GConvLSTM (ch, k, n; kwargs... ) = Flux. Recur (GConvLSTMCell (ch, k, n; kwargs... ))
398
+ Flux. Recur (tgcn:: GConvLSTMCell ) = Flux. Recur (tgcn, tgcn. state0)
399
+
400
+ (l:: Flux.Recur{GConvLSTMCell} )(g:: GNNGraph ) = GNNGraph (g, ndata = l (g, node_features (g)))
401
+ _applylayer (l:: Flux.Recur{GConvLSTMCell} , g:: GNNGraph , x) = l (g, x)
402
+ _applylayer (l:: Flux.Recur{GConvLSTMCell} , g:: GNNGraph ) = l (g)
403
+
282
404
function (l:: GINConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
283
405
return l .(tg. snapshots, x)
284
406
end
0 commit comments