@@ -7,6 +7,73 @@ function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
77    return  stack (y, dims =  2 )
88end 
99
10+ 
"""
    GNNRecurrence(cell)

Construct a recurrent layer that applies the `cell`
to process an entire temporal sequence of node features at once.

# Forward

    layer(g::GNNGraph, x, [state])

- `g`: The input graph.
- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
- `state`: The initial state of the cell.
  If not provided, it is generated by calling `Flux.initialstates(cell)`.

Applies the recurrent cell to each timestep of the input sequence and returns the output as
an array of size `out_features x timesteps x num_nodes`.

# Examples

```jldoctest
julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;

julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

julia> x = rand(Float32, d_in, timesteps, num_nodes);

julia> cell = GConvLSTMCell(d_in => d_out, 2)
GConvLSTMCell(2 => 3, 2)  # 168 parameters

julia> layer = GNNRecurrence(cell)
GNNRecurrence(
  GConvLSTMCell(2 => 3, 2),             # 168 parameters
)                   # Total: 24 arrays, 168 parameters, 2.023 KiB.

julia> y = layer(g, x);

julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
struct GNNRecurrence{G} <: GNNLayer
    cell::G  # any graph recurrent cell, e.g. GConvGRUCell, GConvLSTMCell, DCGRUCell
end
59+ 
Flux.@layer GNNRecurrence

# Delegate state initialization to the wrapped cell.
Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell)

# Forward without an explicit state: start from the cell's initial state.
function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}) where {T}
    return rnn(g, x, initialstates(rnn))
end

# Forward with an explicit state: unroll the cell over the time dimension.
function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
    return scan(rnn.cell, g, x, state)
end

function Base.show(io::IO, rnn::GNNRecurrence)
    print(io, "GNNRecurrence($(rnn.cell))")
end
75+ 
76+ 
1077""" 
1178    GConvGRUCell(in => out, k; [bias, init]) 
1279
@@ -126,24 +193,13 @@ function Base.show(io::IO, cell::GConvGRUCell)
126193end 
127194
"""
    GConvGRU(args...; kws...)

Construct a recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell.
It can be used to process an entire temporal sequence of node features at once.

The arguments are passed to the [`GConvGRUCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.

# Examples

```jldoctest
julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;

julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

julia> x = rand(Float32, d_in, timesteps, num_nodes);

julia> layer = GConvGRU(d_in => d_out, 2)
GNNRecurrence(
  GConvGRUCell(2 => 3, 2),              # 108 parameters
)                   # Total: 12 arrays, 108 parameters, 1.148 KiB.

julia> y = layer(g, x);

julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
GConvGRU(args...; kws...) = GNNRecurrence(GConvGRUCell(args...; kws...))
188229
189230
190231""" 
@@ -268,7 +309,7 @@ julia> size(y) # (d_out, num_nodes)
268309    out:: Int 
269310end 
270311
# :noexpand keeps the cell's compact custom show instead of expanding all fields.
Flux.@layer :noexpand GConvLSTMCell
272313
273314function  GConvLSTMCell (ch:: Pair{Int, Int} , k:: Int ;
274315                        bias:: Bool  =  true ,
@@ -305,6 +346,8 @@ function Flux.initialstates(cell::GConvLSTMCell)
305346    (zeros_like (cell. conv_x_i. weight, cell. out), zeros_like (cell. conv_x_i. weight, cell. out))
306347end 
307348
349+ (cell:: GConvLSTMCell )(g:: GNNGraph , x:: AbstractMatrix ) =  cell (g, x, initialstates (cell))
350+ 
308351function  (cell:: GConvLSTMCell )(g:: GNNGraph , x:: AbstractMatrix , (h, c))
309352    if  h isa  AbstractVector
310353        h =  repeat (h, 1 , g. num_nodes)
@@ -334,29 +377,74 @@ end
334377
335378
"""
    GConvLSTM(args...; kws...)

Construct a recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell.
It can be used to process an entire temporal sequence of node features at once.

The arguments are passed to the [`GConvLSTMCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.

# Examples

```jldoctest
julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;

julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

julia> x = rand(Float32, d_in, timesteps, num_nodes);

julia> layer = GConvLSTM(d_in => d_out, 2)
GNNRecurrence(
  GConvLSTMCell(2 => 3, 2),             # 168 parameters
)                   # Total: 24 arrays, 168 parameters, 2.023 KiB.

julia> y = layer(g, x);

julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
GConvLSTM(args...; kws...) = GNNRecurrence(GConvLSTMCell(args...; kws...))
413+ 
"""
    DCGRUCell(in => out, k; [bias, init])

Diffusion Convolutional Recurrent Neural Network (DCGRU) cell from the paper
[Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/abs/1707.01926).

Applies a [`DConv`](@ref) layer to model spatial dependencies,
in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependencies.

# Arguments

- `in`: Number of input node features.
- `out`: Number of output node features.
- `k`: Diffusion step for the `DConv`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.

# Forward

    cell(g::GNNGraph, x, [h])

- `g`: The input graph.
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
       If not provided, it is assumed to be a matrix of zeros.

Performs one recurrence step and returns a tuple `(h, h)`,
where `h` is the updated hidden state of the GRU cell.

# Examples

```jldoctest
julia> using GraphNeuralNetworks, Flux

julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;

julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];

julia> cell = DCGRUCell(d_in => d_out, 2);

julia> state = Flux.initialstates(cell);

julia> y = state;

julia> for xt in x
           y, state = cell(g, xt, state)
       end

julia> size(y) # (d_out, num_nodes)
(3, 5)
```
"""
struct DCGRUCell
    in::Int        # number of input node features
    out::Int       # number of output node features
    k::Int         # diffusion steps of each DConv
    dconv_u::DConv # update gate convolution
    dconv_r::DConv # reset gate convolution
    dconv_c::DConv # candidate-state convolution
end
381480
# :noexpand keeps the cell's compact custom show instead of expanding all fields.
Flux.@layer :noexpand DCGRUCell

function DCGRUCell(ch::Pair{Int,Int}, k::Int; bias = true, init = glorot_uniform)
    in, out = ch
    # Each gate convolves the concatenation [x; h], hence (in + out) input features.
    dconv_u = DConv((in + out) => out, k; bias, init)
    dconv_r = DConv((in + out) => out, k; bias, init)
    dconv_c = DConv((in + out) => out, k; bias, init)
    return DCGRUCell(in, out, k, dconv_u, dconv_r, dconv_c)
end

# Initial hidden state: a zero vector of length `out`, broadcast to all nodes on first use.
Flux.initialstates(cell::DCGRUCell) = zeros_like(cell.dconv_u.weights, cell.out)

# Forward without an explicit state: start from the cell's initial state.
(cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))

# A vector state is expanded to one column per node.
function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
    return cell(g, x, repeat(h, 1, g.num_nodes))
end

function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
    h̃ = vcat(x, h)
    # Update gate z and reset gate r.
    z = cell.dconv_u(g, h̃)
    z = NNlib.sigmoid_fast.(z)
    r = cell.dconv_r(g, h̃)
    r = NNlib.sigmoid_fast.(r)
    # Candidate state c computed from the reset-modulated hidden state.
    ĥ = vcat(x, h .* r)
    c = cell.dconv_c(g, ĥ)
    c = NNlib.tanh_fast.(c)
    # Convex combination of previous state and candidate (fully dotted to fuse).
    h = z .* h .+ (1 .- z) .* c
    return h, h
end

function Base.show(io::IO, cell::DCGRUCell)
    print(io, "DCGRUCell($(cell.in) => $(cell.out), $(cell.k))")
end
397515
"""
    DCGRU(args...; kws...)

Construct a recurrent layer corresponding to the [`DCGRUCell`](@ref) cell.
It can be used to process an entire temporal sequence of node features at once.

The arguments are passed to the [`DCGRUCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.

# Examples
```jldoctest
julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;

julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

julia> x = rand(Float32, d_in, timesteps, num_nodes);

julia> layer = DCGRU(d_in => d_out, 2)
GNNRecurrence(
  DCGRUCell(2 => 3, 2),                 # 189 parameters
)                   # Total: 6 arrays, 189 parameters, 1.184 KiB.

julia> y = layer(g, x);

julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...))
549+ 
550+ 
0 commit comments