@@ -166,7 +166,7 @@ julia> size(y) # (d_out, timesteps, num_nodes)
 (3, 5, 5)
 ```
 """
-struct GConvGRU{G <: GConvGRUCell} <: GNNLayer
+struct GConvGRU{G<:GConvGRUCell} <: GNNLayer
     cell::G
 end
 
@@ -186,3 +186,212 @@ function Base.show(io::IO, rnn::GConvGRU)
     print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
 end
 
+
+"""
+    GConvLSTMCell(in => out, k; [bias, init])
+
+Graph Convolutional LSTM recurrent cell from the paper
+[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659).
+
+Uses [`ChebConv`](@ref) to model spatial dependencies,
+followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
+
+# Arguments
+
+- `in => out`: A pair where `in` is the number of input node features and `out`
+  is the number of output node features.
+- `k`: Chebyshev polynomial order.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+
+# Forward
+
+    cell(g::GNNGraph, x, [state])
+
+- `g`: The input graph.
+- `x`: The node features. It should be a matrix of size `in x num_nodes`.
+- `state`: The initial hidden state of the LSTM cell.
+    If given, it is a tuple `(h, c)` where both `h` and `c` are arrays of size `out x num_nodes`.
+    If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
+
+Performs one recurrence step and returns a tuple `(output, state)`,
+where `output` is the updated hidden state `h` of the LSTM cell and `state` is the updated tuple `(h, c)`.
+
+# Examples
+
+```jldoctest
+julia> using GraphNeuralNetworks, Flux
+
+julia> num_nodes, num_edges = 5, 10;
+
+julia> d_in, d_out = 2, 3;
+
+julia> timesteps = 5;
+
+julia> g = rand_graph(num_nodes, num_edges);
+
+julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];
+
+julia> cell = GConvLSTMCell(d_in => d_out, 2);
+
+julia> state = Flux.initialstates(cell);
+
+julia> y = state[1];
+
+julia> for xt in x
+           y, state = cell(g, xt, state)
+       end
+
+julia> size(y) # (d_out, num_nodes)
+(3, 5)
+```
+"""
+@concrete struct GConvLSTMCell <: GNNLayer
+    conv_x_i
+    conv_h_i
+    w_i
+    b_i
+    conv_x_f
+    conv_h_f
+    w_f
+    b_f
+    conv_x_c
+    conv_h_c
+    w_c
+    b_c
+    conv_x_o
+    conv_h_o
+    w_o
+    b_o
+    k::Int
+    in::Int
+    out::Int
+end
+
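+# `Flux.@layer` registers the cell with Flux so that its fields (the gate
+# convolutions, peephole weights, and biases) are collected as trainable
+# parameters.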
+Flux.@layer GConvLSTMCell
+
+function GConvLSTMCell(ch::Pair{Int, Int}, k::Int;
+                       bias::Bool = true,
+                       init = Flux.glorot_uniform)
+    in, out = ch
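+    # Each gate combines a ChebConv on the input features, a ChebConv on the
+    # hidden state, a peephole weight vector `w_*` multiplied elementwise with
+    # the cell state, and an optional bias `b_*`.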
+    # input gate
+    conv_x_i = ChebConv(in => out, k; bias, init)
+    conv_h_i = ChebConv(out => out, k; bias, init)
+    w_i = init(out, 1)
+    b_i = bias ? Flux.create_bias(w_i, true, out) : false
+    # forget gate
+    conv_x_f = ChebConv(in => out, k; bias, init)
+    conv_h_f = ChebConv(out => out, k; bias, init)
+    w_f = init(out, 1)
+    b_f = bias ? Flux.create_bias(w_f, true, out) : false
+    # cell state
+    conv_x_c = ChebConv(in => out, k; bias, init)
+    conv_h_c = ChebConv(out => out, k; bias, init)
+    w_c = init(out, 1)
+    b_c = bias ? Flux.create_bias(w_c, true, out) : false
+    # output gate
+    conv_x_o = ChebConv(in => out, k; bias, init)
+    conv_h_o = ChebConv(out => out, k; bias, init)
+    w_o = init(out, 1)
+    b_o = bias ? Flux.create_bias(w_o, true, out) : false
+    return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i,
+                         conv_x_f, conv_h_f, w_f, b_f,
+                         conv_x_c, conv_h_c, w_c, b_c,
+                         conv_x_o, conv_h_o, w_o, b_o,
+                         k, in, out)
+end
+
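+# The initial state is a pair `(h0, c0)` of zero vectors of length `out`;
+# the forward pass broadcasts them across all nodes of the input graph.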
+function Flux.initialstates(cell::GConvLSTMCell)
+    (zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out))
+end
+
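+# One step of the peephole LSTM update from the paper linked above
+# (arXiv:1612.07659), with Chebyshev graph convolutions in place of the
+# dense input/hidden projections (⊙ is the elementwise product):
+#   i = σ(conv_x_i(x) + conv_h_i(h) + w_i ⊙ c + b_i)
+#   f = σ(conv_x_f(x) + conv_h_f(h) + w_f ⊙ c + b_f)
+#   c = f ⊙ c + i ⊙ tanh(conv_x_c(x) + conv_h_c(h) + w_c ⊙ c + b_c)
+#   o = σ(conv_x_o(x) + conv_h_o(h) + w_o ⊙ c + b_o)   # uses the updated c
+#   h = o ⊙ tanh(c)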
+function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
+    if h isa AbstractVector
+        h = repeat(h, 1, g.num_nodes)
+    end
+    if c isa AbstractVector
+        c = repeat(c, 1, g.num_nodes)
+    end
+    @assert ndims(h) == 2 && ndims(c) == 2
+    # input gate
+    i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i
+    i = Flux.sigmoid_fast(i)
+    # forget gate
+    f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f
+    f = Flux.sigmoid_fast(f)
+    # cell state
+    c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c)
+    # output gate
+    o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o
+    o = Flux.sigmoid_fast(o)
+    h = o .* Flux.tanh_fast(c)
+    return h, (h, c)
+end
+
+function Base.show(io::IO, cell::GConvLSTMCell)
+    print(io, "GConvLSTMCell($(cell.in) => $(cell.out), $(cell.k))")
+end
+
+
+"""
+    GConvLSTM(in => out, k; kws...)
+
+The recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell,
+used to process an entire temporal sequence of node features at once.
+
+The arguments are the same as for [`GConvLSTMCell`](@ref).
+
+# Forward
+
+    layer(g::GNNGraph, x, [state])
+
+- `g`: The input graph.
+- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
+- `state`: The initial hidden state of the LSTM cell.
+    If given, it is a tuple `(h, c)` where both elements are matrices of size `out x num_nodes`.
+    If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
+
+Applies the recurrent cell to each timestep of the input sequence and returns the output as
+an array of size `out x timesteps x num_nodes`.
+
+# Examples
+
+```jldoctest
+julia> num_nodes, num_edges = 5, 10;
+
+julia> d_in, d_out = 2, 3;
+
+julia> timesteps = 5;
+
+julia> g = rand_graph(num_nodes, num_edges);
+
+julia> x = rand(Float32, d_in, timesteps, num_nodes);
+
+julia> layer = GConvLSTM(d_in => d_out, 2);
+
+julia> y = layer(g, x);
+
+julia> size(y) # (d_out, timesteps, num_nodes)
+(3, 5, 5)
+```
+"""
+struct GConvLSTM{G<:GConvLSTMCell} <: GNNLayer
+    cell::G
+end
+
+Flux.@layer GConvLSTM
+
+function GConvLSTM(ch::Pair{Int,Int}, k::Int; kws...)
+    return GConvLSTM(GConvLSTMCell(ch, k; kws...))
+end
+
+Flux.initialstates(rnn::GConvLSTM) = Flux.initialstates(rnn.cell)
+
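+# `scan` slides the cell along the time dimension of `x`, threading the state
+# through the sequence and stacking the per-step outputs into an
+# `out x timesteps x num_nodes` array.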
+function (rnn::GConvLSTM)(g::GNNGraph, x::AbstractArray)
+    return scan(rnn.cell, g, x, initialstates(rnn))
+end
+
+function Base.show(io::IO, rnn::GConvLSTM)
+    print(io, "GConvLSTM($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
+end
+