@@ -401,6 +401,89 @@ Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)

+ struct DCGRUCell
+     in::Int
+     out::Int
+     state0
+     k::Int
+     dconv_u::DConv
+     dconv_r::DConv
+     dconv_c::DConv
+ end
+
+ Flux.@functor DCGRUCell
+
+ function DCGRUCell(ch::Pair{Int, Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
+     in, out = ch
+     dconv_u = DConv((in + out) => out, k; bias = bias, init = init)
+     dconv_r = DConv((in + out) => out, k; bias = bias, init = init)
+     dconv_c = DConv((in + out) => out, k; bias = bias, init = init)
+     state0 = init_state(out, n)
+     return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c)
+ end
+
+ function (dcgru::DCGRUCell)(h, g::GNNGraph, x)
+     h̃ = vcat(x, h)
+     # update gate
+     z = dcgru.dconv_u(g, h̃)
+     z = NNlib.sigmoid_fast.(z)
+     # reset gate
+     r = dcgru.dconv_r(g, h̃)
+     r = NNlib.sigmoid_fast.(r)
+     # candidate state, computed from the input and the reset hidden state
+     ĥ = vcat(x, h .* r)
+     c = dcgru.dconv_c(g, ĥ)
+     c = tanh.(c)
+     # convex combination of the previous hidden state and the candidate
+     h = z .* h + (1 .- z) .* c
+     return h, h
+ end
+
+ function Base.show(io::IO, dcgru::DCGRUCell)
+     print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))")
+ end
+
+ """
+     DCGRU(in => out, k, n; [bias, init, init_state])
+
+ Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
+
+ Applies a diffusion convolution to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
+
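+ Schematically, writing ``\mathbf{x}`` for the node features and ``\mathbf{h}`` for the hidden state, one step of the cell amounts to the following sketch of the update, where subscripts name the three internal `DConv` layers, ``\sigma`` is the sigmoid, and ``\odot`` the elementwise product:
+
+ ```math
+ \begin{aligned}
+ \mathbf{z} &= \sigma\big(\mathrm{DConv}_u(g, [\mathbf{x}; \mathbf{h}])\big), \\
+ \mathbf{r} &= \sigma\big(\mathrm{DConv}_r(g, [\mathbf{x}; \mathbf{h}])\big), \\
+ \mathbf{c} &= \tanh\big(\mathrm{DConv}_c(g, [\mathbf{x}; \mathbf{r} \odot \mathbf{h}])\big), \\
+ \mathbf{h} &= \mathbf{z} \odot \mathbf{h} + (1 - \mathbf{z}) \odot \mathbf{c}.
+ \end{aligned}
+ ```
+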
+ # Arguments
+
+ - `in`: Number of input features.
+ - `out`: Number of output features.
+ - `k`: Number of diffusion steps.
+ - `n`: Number of nodes in the graph.
+ - `bias`: Add learnable bias. Default `true`.
+ - `init`: Weights' initializer. Default `glorot_uniform`.
+ - `init_state`: Initializer of the hidden state of the GRU cell. Default `zeros32`.
+
+ # Examples
+
+ ```jldoctest
+ julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
+
+ julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes);
+
+ julia> y = dcgru(g1, x1);
+
+ julia> size(y)
+ (5, 5)
+
+ julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
+
+ julia> z = dcgru(g2, x2);
+
+ julia> size(z)
+ (5, 5, 30)
+ ```
+ """
+ DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
+ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
+
+ (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+ _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
+ _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
+
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end