@@ -130,7 +130,7 @@ function (rnn::GNNRecurrence)(g, x)
130130    return  rnn (g, x, initialstates (rnn))
131131end 
132132
133- function  (rnn:: GNNRecurrence )(g, x, state)  where  {T} 
133+ function  (rnn:: GNNRecurrence )(g, x, state)
134134    return  scan (rnn. cell, g, x, state)
135135end 
136136
@@ -750,7 +750,60 @@ julia> size(y[end]) # (d_out, num_nodes[end])
750750EvolveGCNO (args... ; kws... ) =  GNNRecurrence (EvolveGCNOCell (args... ; kws... ))
751751
752752
753+ """ 
754+     TGCNCell(in => out; kws...) 
755+ 
756+ Recurrent graph convolutional cell from the paper 
757+ [T-GCN: A Temporal Graph Convolutional 
758+ Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320). 
759+ 
760+ Uses two stacked [`GCNConv`](@ref) layers to model spatial dependencies, 
761+ and a GRU mechanism to model temporal dependencies. 
762+ 
763+ `in` and `out` are the number of input and output node features, respectively. 
764+ The keyword arguments are passed to the [`GCNConv`](@ref) constructor. 
765+ 
766+ # Forward  
767+ 
768+     cell(g::GNNGraph, x, [state]) 
769+ 
770+ - `g`: The input graph. 
771+ - `x`: The node features. It should be a matrix of size `in x num_nodes`. 
772+ - `state`: The current state of the cell. 
773+     If not provided, it is generated by calling `Flux.initialstates(cell)`. 
774+     The state is a matrix of size `out x num_nodes`. 
775+ 
776+ Returns the updated node features and the updated state. 
777+ 
778+ # Examples 
779+ 
780+ ```jldoctest 
781+ julia> using GraphNeuralNetworks, Flux 
782+ 
783+ julia> num_nodes, num_edges = 5, 10; 
784+ 
785+ julia> d_in, d_out = 2, 3; 
786+ 
787+ julia> timesteps = 5; 
788+ 
789+ julia> g = rand_graph(num_nodes, num_edges); 
753790
791+ julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; 
792+ 
793+ julia> cell = DCGRUCell(d_in => d_out, 2); 
794+ 
795+ julia> state = Flux.initialstates(cell); 
796+ 
797+ julia> y = state; 
798+ 
799+ julia> for xt in x 
800+            y, state = cell(g, xt, state) 
801+        end 
802+ 
803+ julia> size(y) # (d_out, num_nodes) 
804+ (3, 5) 
805+ ``` 
806+ """ 
754807@concrete  struct  TGCNCell <:  GNNLayer 
755808    in:: Int 
756809    out:: Int 
795848
796849function  Base. show (io:: IO , cell:: TGCNCell )
797850    print (io, " TGCNCell($(cell. in)  => $(cell. out) )" 
798- end 
851+ end 
852+ 
853+ """ 
854+     TGCN(args...; kws...) 
855+ 
856+ Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell. 
857+ 
858+ The arguments are passed to the [`TGCNCell`](@ref) constructor. 
859+ See [`GNNRecurrence`](@ref) for more details. 
860+ 
861+ # Examples 
862+ 
863+ ```jldoctest 
864+ julia> num_nodes, num_edges = 5, 10; 
865+ 
866+ julia> d_in, d_out = 2, 3; 
867+ 
868+ julia> timesteps = 5; 
869+ 
870+ julia> g = rand_graph(num_nodes, num_edges); 
871+ 
872+ julia> x = rand(Float32, d_in, timesteps, num_nodes); 
873+ 
874+ julia> layer = TGCN(d_in => d_out) 
875+ 
876+ julia> y = layer(g, x); 
877+ 
878+ julia> size(y) # (d_out, timesteps, num_nodes) 
879+ (3, 5, 5) 
880+ ``` 
881+ """ 
882+ TGCN (args... ; kws... ) =  GNNRecurrence (TGCNCell (args... ; kws... ))
883+ 
0 commit comments