2222""" 
2323    GNNRecurrence(cell) 
2424
25- Construct a recurrent layer that applies the `cell` 
26- to process an entire temporal sequence of node features at once. 
25+ Construct a recurrent layer that applies the graph recurrent `cell` forward 
26+ multiple times to process an entire temporal sequence of node features at once. 
27+ 
28+ The `cell` has to satisfy the following interface for the forward pass:  
29+ `yt, state = cell(g, xt, state)`, where `xt` is the input node features, 
30+ `yt` is the updated node features, `state` is the cell state to be updated. 
2731
2832# Forward  
2933
30-     layer(g::GNNGraph , x, [state]) 
34+     layer(g, x, [state]) 
3135
32- - `g`: The input graph. 
33- - `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`. 
34- - `state`: The current state of the cell.  
36+ Applies the recurrent cell to each timestep of the input sequence. 
37+ 
38+ ## Arguments 
39+ 
40+ - `g`: The input `GNNGraph` or `TemporalSnapshotsGNNGraph`. 
41+     - If `GNNGraph`, the same graph is used for all timesteps. 
42+     - If `TemporalSnapshotsGNNGraph`, a different graph is used for each timestep. Not all cells support this. 
43+ - `x`: The time-varying node features.  
44+     - If `g` is `GNNGraph`, it is an array of size `in x timesteps x num_nodes`. 
45+     - If `g` is `TemporalSnapshotsGNNGraph`, it is an vector of length `timesteps`, 
46+       with element `t` of size `in x num_nodes_t`. 
47+ - `state`: The initial state for the cell.  
3548   If not provided, it is generated by calling `Flux.initialstates(cell)`. 
3649
37- Applies the recurrent cell to each timestep of the input sequence and returns the output as 
38- an array of size `out_features x timesteps x num_nodes`. 
50+ ## Return 
51+ 
52+ Returns the updated node features: 
53+ - If `g` is `GNNGraph`, returns an array of size `out_features x timesteps x num_nodes`. 
54+ - If `g` is `TemporalSnapshotsGNNGraph`, returns a vector of length `timesteps`,  
55+   with element `t` of size `out_features x num_nodes_t`. 
3956
4057# Examples 
4158
59+ The following example considers a static graph and a time-varying node features. 
60+ 
4261```jldoctest 
4362julia> num_nodes, num_edges = 5, 10; 
4463
@@ -47,6 +66,9 @@ julia> d_in, d_out = 2, 3;
4766julia> timesteps = 5; 
4867
4968julia> g = rand_graph(num_nodes, num_edges); 
69+ GNNGraph: 
70+   num_nodes: 5 
71+   num_edges: 10 
5072
5173julia> x = rand(Float32, d_in, timesteps, num_nodes); 
5274
@@ -63,6 +85,38 @@ julia> y = layer(g, x);
6385julia> size(y) # (d_out, timesteps, num_nodes) 
6486(3, 5, 5) 
6587``` 
88+ Now consider a time-varying graph and time-varying node features. 
89+ ```jldoctest 
90+ julia> d_in, d_out = 2, 3; 
91+ 
92+ julia> timesteps = 5; 
93+ 
94+ julia> num_nodes = [10, 10, 10, 10, 10]; 
95+ 
96+ julia> num_edges = [10, 12, 14, 16, 18]; 
97+ 
98+ julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)]; 
99+ 
100+ julia> tg = TemporalSnapshotsGNNGraph(snapshots) 
101+ 
102+ julia> x = [rand(Float32, d_in, n) for n in num_nodes]; 
103+ 
104+ julia> cell = EvolveGCNOCell(d_in => d_out) 
105+ EvolveGCNOCell(2 => 3)  # 321 parameters 
106+ 
107+ julia> layer = GNNRecurrence(cell) 
108+ GNNRecurrence( 
109+   EvolveGCNOCell(2 => 3),               # 321 parameters 
110+ )                   # Total: 5 arrays, 321 parameters, 1.535 KiB. 
111+ 
112+ julia> y = layer(tg, x); 
113+ 
114+ julia> length(y)    # timesteps 
115+ 5 
116+ 
117+ julia> size(y[end]) # (d_out, num_nodes[end]) 
118+ (3, 10) 
119+ ``` 
66120""" 
67121struct  GNNRecurrence{G} <:  GNNLayer 
68122    cell:: G 
@@ -437,7 +491,7 @@ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependen
437491- `out`: Number of output node features. 
438492- `k`: Diffusion step for the `DConv`. 
439493- `bias`: Add learnable bias. Default `true`. 
440- - `init`: Weights ' initializer. Default `glorot_uniform`. 
494+ - `init`: Convolution weights ' initializer. Default `glorot_uniform`. 
441495
442496# Forward  
443497
651705function  Base. show (io:: IO , egcno:: EvolveGCNOCell )
652706    print (io, " EvolveGCNOCell($(egcno. in)  => $(egcno. out) )" 
653707end 
708+ 
709+ 
710+ """ 
711+     EvolveGCNO(args...; kws...) 
712+ 
713+ Construct a recurrent layer corresponding to the [`EvolveGCNOCell`](@ref) cell. 
714+ It can be used to process an entire temporal sequence of graphs and node features at once. 
715+ 
716+ The arguments are passed to the [`EvolveGCNOCell`](@ref) constructor. 
717+ See [`GNNRecurrence`](@ref) for more details. 
718+ 
719+ # Examples 
720+ 
721+ ```jldoctest 
722+ julia> d_in, d_out = 2, 3; 
723+ 
724+ julia> timesteps = 5; 
725+ 
726+ julia> num_nodes = [10, 10, 10, 10, 10]; 
727+ 
728+ julia> num_edges = [10, 12, 14, 16, 18]; 
729+ 
730+ julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)]; 
731+ 
732+ julia> tg = TemporalSnapshotsGNNGraph(snapshots) 
733+ 
734+ julia> x = [rand(Float32, d_in, n) for n in num_nodes]; 
735+ 
736+ julia> cell = EvolveGCNO(d_in => d_out) 
737+ GNNRecurrence( 
738+   EvolveGCNOCell(2 => 3),               # 321 parameters 
739+ )                   # Total: 5 arrays, 321 parameters, 1.535 KiB. 
740+ 
741+ julia> y = layer(tg, x); 
742+ 
743+ julia> length(y)    # timesteps 
744+ 5  
745+ 
746+ julia> size(y[end]) # (d_out, num_nodes[end]) 
747+ (3, 10) 
748+ ``` 
749+ """ 
750+ EvolveGCNO (args... ; kws... ) =  GNNRecurrence (EvolveGCNOCell (args... ; kws... ))
0 commit comments