# Recurrent cell for the DCGRU layer: diffusion convolutions combined with
# GRU-style gating. Field order must match the positional inner-constructor
# call used by the keyword constructor below.
struct DCGRUCell
    in::Int          # number of input features
    out::Int         # number of output features
    state0           # initial hidden state, shape (out, n); untyped so any array backend works
    k::Int           # diffusion step (order of the diffusion convolution)
    dconv_u::DConv   # diffusion convolution for the update gate (presumably; follows GRU naming)
    dconv_r::DConv   # diffusion convolution for the reset gate
    dconv_c::DConv   # diffusion convolution for the candidate state
end
# Register DCGRUCell with Flux so its trainable parameters are collected.
Flux.@functor DCGRUCell
"""
    DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32)

Construct a `DCGRUCell` mapping `in => out` features over a graph with `n` nodes,
using diffusion convolutions of order `k` for the update, reset, and candidate gates.
"""
function DCGRUCell(ch::Pair{Int, Int}, k::Int, n::Int;
                   bias = true, init = glorot_uniform, init_state = Flux.zeros32)
    in, out = ch
    # Each gate operates on the concatenation of input and hidden features,
    # hence the (in + out) input width.
    dconv_u = DConv((in + out) => out, k; bias = bias, init = init)
    dconv_r = DConv((in + out) => out, k; bias = bias, init = init)
    dconv_c = DConv((in + out) => out, k; bias = bias, init = init)
    state0 = init_state(out, n)
    return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c)
end
424
424
425
425
function (dcgru:: DCGRUCell )(h, g:: GNNGraph , x)
@@ -436,10 +436,48 @@ function (dcgru::DCGRUCell)(h, g::GNNGraph, x)
436
436
end
437
437
438
438
# Compact one-line display, e.g. `DCGRUCell(2 => 5, 2)`.
function Base.show(io::IO, dcgru::DCGRUCell)
    print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))")
end
"""
    DCGRU(in => out, k, n; [bias, init, init_state])

Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper
[Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic
Forecasting](https://arxiv.org/pdf/1707.01926).

Performs a Diffusion Convolutional layer to model spatial dependencies,
followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Diffusion step.
- `n`: Number of nodes in the graph.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initializer of the hidden state of the GRU cell. Default `zeros32`.

# Examples

```jldoctest
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);

julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes);

julia> y = dcgru(g1, x1);

julia> size(y)
(5, 5)

julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);

julia> z = dcgru(g2, x2);

julia> size(z)
(5, 5, 30)
```
"""
DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
# Wrap a DCGRUCell in Flux.Recur, seeding the recurrence with the cell's
# stored initial hidden state.
Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
# Graph-only call: run the recurrent layer on the graph's node features and
# return a new GNNGraph carrying the outputs as node data.
(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
0 commit comments