1- #  Adapting Flux.Recur to work with GNNGraphs
2- function  (m:: Flux.Recur )(g:: GNNGraph , x)
3-     m. state, y =  m. cell (m. state, g, x)
4-     return  y
5- end 
1+ #  #  Adapting Flux.Recur to work with GNNGraphs
2+ #   function (m::Flux.Recur)(g::GNNGraph, x)
3+ #       m.state, y = m.cell(m.state, g, x)
4+ #       return y
5+ #   end
66
7- function  (m:: Flux.Recur )(g:: GNNGraph , x:: AbstractArray{T, 3} ) where  T
8-     h =  [m (g, x_t) for  x_t in  Flux. eachlastdim (x)]
9-     sze =  size (h[1 ])
10-     reshape (reduce (hcat, h), sze[1 ], sze[2 ], length (h))
11- end 
7+ #   function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T
8+ #       h = [m(g, x_t) for x_t in Flux.eachlastdim(x)]
9+ #       sze = size(h[1])
10+ #       reshape(reduce(hcat, h), sze[1], sze[2], length(h))
11+ #   end
1212
1313struct  TGCNCell <:  GNNLayer 
1414    conv:: GCNConv 
1515    gru:: Flux.GRUv3Cell 
16-     state0
1716    in:: Int 
1817    out:: Int 
1918end 
@@ -23,29 +22,26 @@ Flux.@layer TGCNCell
2322function  TGCNCell (ch:: Pair{Int, Int} ;
2423                  bias:: Bool  =  true ,
2524                  init =  Flux. glorot_uniform,
26-                   init_state =  Flux. zeros32,
2725                  add_self_loops =  false ,
2826                  use_edge_weight =  true )
2927    in, out =  ch
30-     conv =  GCNConv (in =>  out, sigmoid; init, bias, add_self_loops,
31-                    use_edge_weight)
32-     gru =  Flux. GRUv3Cell (out, out)
33-     state0 =  init_state (out,1 )
34-     return  TGCNCell (conv, gru, state0, in,out)
28+     conv =  GCNConv (in =>  out, sigmoid; init, bias, add_self_loops, use_edge_weight)
29+     gru =  Flux. GRUCell (out =>  out)
30+     return  TGCNCell (conv, gru, in, out)
3531end 
3632
37- function  (tgcn:: TGCNCell )(h,  g:: GNNGraph , x:: AbstractArray )
38-     x̃  =  tgcn. conv (g, x)
39-     h, x̃  =  tgcn. gru (h, x̃ )
40-     return  h, x̃ 
33+ function  (tgcn:: TGCNCell )(g:: GNNGraph , x:: AbstractVecOrMat , h :: AbstractVecOrMat )
34+     x  =  tgcn. conv (g, x)
35+     x, h  =  tgcn. gru (x, h )
36+     return  x, h 
4137end 
4238
4339function  Base. show (io:: IO , tgcn:: TGCNCell )
4440    print (io, " TGCNCell($(tgcn. in)  => $(tgcn. out) )" 
4541end 
4642
4743""" 
48-     TGCN(in => out; [bias, init, init_state,  add_self_loops, use_edge_weight]) 
44+     TGCN(in => out; [bias, init, add_self_loops, use_edge_weight]) 
4945
5046Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). 
5147
@@ -57,12 +53,20 @@ Performs a layer of GCNConv to model spatial dependencies, followed by a Gated R
5753- `out`: Number of output features. 
5854- `bias`: Add learnable bias. Default `true`. 
5955- `init`: Weights' initializer. Default `glorot_uniform`. 
60- - `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`. 
6156- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. 
6257- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). 
6358                     If `add_self_loops=true` the new weights will be set to 1.  
6459                     This option is ignored if the `edge_weight` is explicitly provided in the forward pass. 
6560                     Default `false`. 
61+ 
62+ # Forward  
63+ 
64+     tgcn(g::GNNGraph, x, [h]) 
65+ 
66+ - `x`: The input to the TGCN. It should be a matrix size `in x timesteps` or an array of size `in x timesteps x num_nodes`. 
67+ - `h`: The initial hidden state of the GRU cell. If given, it is a vector of size `out` or a matrix of size `out x num_nodes`. 
68+        If not provided, it is assumed to be a vector of zeros. 
69+ 
6670# Examples 
6771
6872```jldoctest 
@@ -78,30 +82,43 @@ Recur(
7882)         # Total: 8 trainable arrays, 264 parameters, 
7983          # plus 1 non-trainable, 6 parameters, summarysize 1.492 KiB. 
8084
81- julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5); 
85+ julia> g = rand_graph(5, 10); 
86+ 
87+ julia> x = rand(Float32, 2, 5); 
8288
8389julia> y = tgcn(g, x); 
8490
8591julia> size(y) 
8692(6, 5) 
8793
88- julia> Flux.reset!(tgcn); 
89- 
90- julia> tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)) |> size # batch size of 20 
94+ julia> tgcn(g, rand(Float32, 2, 5, 20)) |> size # batch size of 20 
9195(6, 5, 20) 
9296``` 
93- 
94- !!! warning "Batch size changes" 
95-     Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. 
9697""" 
97- TGCN (ch; kwargs... ) =  Flux. Recur (TGCNCell (ch; kwargs... ))
98+ struct  TGCN
99+     tgcn:: TGCNCell 
100+ end 
101+ 
102+ Flux. @layer  TGCN
103+ 
104+ TGCN (ch:: Pair{Int, Int} ; kws... ) =  TGCN (TGCNCell (ch; kws... ))
98105
99- Flux. Recur (tgcn:: TGCNCell ) =  Flux. Recur (tgcn, tgcn. state0)
106+ function  (tgcn:: TGCN )(g:: GNNGraph , x:: AbstractArray , h)
107+     for  i in  1 : size (x, 2 )
108+         x, h =  tgcn. tgcn (g, x[:, i], h)
109+     end 
110+     return  x
111+ end 
112+     
100113
101- #  make TGCN compatible with GNNChain
102- (l:: Flux.Recur{TGCNCell} )(g:: GNNGraph ) =  GNNGraph (g, ndata =  l (g, node_features (g)))
103- _applylayer (l:: Flux.Recur{TGCNCell} , g:: GNNGraph , x) =  l (g, x)
104- _applylayer (l:: Flux.Recur{TGCNCell} , g:: GNNGraph ) =  l (g)
114+ #  TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...))
115+ 
116+ #  Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0)
117+ 
118+ #  # make TGCN compatible with GNNChain
119+ #  (l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
120+ #  _applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
121+ #  _applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g)
105122
106123
107124""" 
@@ -149,7 +166,7 @@ julia> size(y)
149166    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. 
150167""" 
151168struct  A3TGCN <:  GNNLayer 
152-     tgcn:: Flux.Recur{TGCNCell} 
169+     tgcn:: TGCN 
153170    dense1:: Dense 
154171    dense2:: Dense 
155172    in:: Int 
@@ -272,12 +289,12 @@ julia> size(z)
272289(5, 5, 30) 
273290``` 
274291""" 
275- GConvGRU (ch, k, n; kwargs... ) =  Flux. Recur (GConvGRUCell (ch, k, n; kwargs... ))
276- Flux. Recur (ggru:: GConvGRUCell ) =  Flux. Recur (ggru, ggru. state0)
292+ #   GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
293+ #   Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
277294
278- (l:: Flux.Recur{GConvGRUCell} )(g:: GNNGraph ) =  GNNGraph (g, ndata =  l (g, node_features (g)))
279- _applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph , x) =  l (g, x)
280- _applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph ) =  l (g)
295+ #   (l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
296+ #   _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
297+ #   _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
281298
282299struct  GConvLSTMCell <:  GNNLayer 
283300    conv_x_i:: ChebConv 
@@ -394,12 +411,12 @@ julia> size(z)
394411(5, 5, 30) 
395412``` 
396413""" 
397- GConvLSTM (ch, k, n; kwargs... ) =  Flux. Recur (GConvLSTMCell (ch, k, n; kwargs... ))
398- Flux. Recur (tgcn:: GConvLSTMCell ) =  Flux. Recur (tgcn, tgcn. state0)
414+ #   GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
415+ #   Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
399416
400- (l:: Flux.Recur{GConvLSTMCell} )(g:: GNNGraph ) =  GNNGraph (g, ndata =  l (g, node_features (g)))
401- _applylayer (l:: Flux.Recur{GConvLSTMCell} , g:: GNNGraph , x) =  l (g, x)
402- _applylayer (l:: Flux.Recur{GConvLSTMCell} , g:: GNNGraph ) =  l (g)
417+ #   (l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
418+ #   _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
419+ #   _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)
403420
404421struct  DCGRUCell
405422    in:: Int 
@@ -477,12 +494,12 @@ julia> size(z)
477494(5, 5, 30) 
478495``` 
479496""" 
480- DCGRU (ch, k, n; kwargs... ) =  Flux. Recur (DCGRUCell (ch, k, n; kwargs... ))
481- Flux. Recur (dcgru:: DCGRUCell ) =  Flux. Recur (dcgru, dcgru. state0)
497+ #   DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
498+ #   Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
482499
483- (l:: Flux.Recur{DCGRUCell} )(g:: GNNGraph ) =  GNNGraph (g, ndata =  l (g, node_features (g)))
484- _applylayer (l:: Flux.Recur{DCGRUCell} , g:: GNNGraph , x) =  l (g, x)
485- _applylayer (l:: Flux.Recur{DCGRUCell} , g:: GNNGraph ) =  l (g)
500+ #   (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
501+ #   _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
502+ #   _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
486503
487504""" 
488505    EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) 
@@ -539,8 +556,6 @@ struct EvolveGCNO
539556    Bc
540557end 
541558
542- Flux. @functor  EvolveGCNO
543- 
544559function  EvolveGCNO (ch; bias =  true , init =  glorot_uniform, init_state =  Flux. zeros32)
545560    in, out =  ch
546561    W =  init (out, in)
580595function  Base. show (io:: IO , egcno:: EvolveGCNO )
581596    print (io, " EvolveGCNO($(egcno. in)  => $(egcno. out) )" 
582597end 
583- 
584- function  (l:: GINConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
585-     return  l .(tg. snapshots, x)
586- end 
587- 
588- function  (l:: ChebConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
589-     return  l .(tg. snapshots, x)
590- end 
591- 
592- function  (l:: GATConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
593-     return  l .(tg. snapshots, x)
594- end 
595- 
596- function  (l:: GATv2Conv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
597-     return  l .(tg. snapshots, x)
598- end 
599- 
600- function  (l:: GatedGraphConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
601-     return  l .(tg. snapshots, x)
602- end 
603- 
604- function  (l:: CGConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
605-     return  l .(tg. snapshots, x)
606- end 
607- 
608- function  (l:: SGConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
609-     return  l .(tg. snapshots, x)
610- end 
611- 
612- function  (l:: TransformerConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
613-     return  l .(tg. snapshots, x)
614- end 
615- 
616- function  (l:: GCNConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
617-     return  l .(tg. snapshots, x)
618- end 
619- 
620- function  (l:: ResGatedGraphConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
621-     return  l .(tg. snapshots, x)
622- end 
623- 
624- function  (l:: SAGEConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
625-     return  l .(tg. snapshots, x)
626- end 
627- 
628- function  (l:: GraphConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
629-     return  l .(tg. snapshots, x)
630- end 
0 commit comments