-# Adapting Flux.Recur to work with GNNGraphs
-function (m::Flux.Recur)(g::GNNGraph, x)
-    m.state, y = m.cell(m.state, g, x)
-    return y
-end
+# # Adapting Flux.Recur to work with GNNGraphs
+# function (m::Flux.Recur)(g::GNNGraph, x)
+#     m.state, y = m.cell(m.state, g, x)
+#     return y
+# end
 
-function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T
-    h = [m(g, x_t) for x_t in Flux.eachlastdim(x)]
-    sze = size(h[1])
-    reshape(reduce(hcat, h), sze[1], sze[2], length(h))
-end
+# function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T
+#     h = [m(g, x_t) for x_t in Flux.eachlastdim(x)]
+#     sze = size(h[1])
+#     reshape(reduce(hcat, h), sze[1], sze[2], length(h))
+# end
 
 struct TGCNCell <: GNNLayer
     conv::GCNConv
-    gru::Flux.GRUv3Cell
-    state0
+    gru::Flux.GRUCell
     in::Int
     out::Int
 end
@@ -23,29 +22,26 @@ Flux.@layer TGCNCell
 function TGCNCell(ch::Pair{Int, Int};
                   bias::Bool = true,
                   init = Flux.glorot_uniform,
-                  init_state = Flux.zeros32,
                   add_self_loops = false,
                   use_edge_weight = true)
     in, out = ch
-    conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops,
-                   use_edge_weight)
-    gru = Flux.GRUv3Cell(out, out)
-    state0 = init_state(out, 1)
-    return TGCNCell(conv, gru, state0, in, out)
+    conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops, use_edge_weight)
+    gru = Flux.GRUCell(out => out)
+    return TGCNCell(conv, gru, in, out)
 end
 
-function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray)
-    x̃ = tgcn.conv(g, x)
-    h, x̃ = tgcn.gru(h, x̃)
-    return h, x̃
+function (tgcn::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat, h::AbstractVecOrMat)
+    x = tgcn.conv(g, x)
+    h = tgcn.gru(x, h)  # the GRU cell returns the updated hidden state, which is also the cell output
+    return h, h
 end
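For reference, a minimal single-step usage sketch of the new cell interface (not part of this commit; the graph size, the feature sizes, and the zero initial state are illustrative assumptions):

```julia
using Flux, GraphNeuralNetworks  # assumes TGCNCell and rand_graph are in scope

g = rand_graph(5, 10)        # 5 nodes, 10 edges
x = rand(Float32, 2, 5)      # 2 input features per node
cell = TGCNCell(2 => 6)
h = zeros(Float32, 6)        # initial hidden state of size `out`
y, h = cell(g, x, h)         # one recurrent step; both outputs are 6 × 5 here
```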
 
 function Base.show(io::IO, tgcn::TGCNCell)
     print(io, "TGCNCell($(tgcn.in) => $(tgcn.out))")
 end
 
 """
-    TGCN(in => out; [bias, init, init_state, add_self_loops, use_edge_weight])
+    TGCN(in => out; [bias, init, add_self_loops, use_edge_weight])
 
 Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).
 
@@ -57,12 +53,20 @@ Performs a layer of GCNConv to model spatial dependencies, followed by a Gated R
 - `out`: Number of output features.
 - `bias`: Add learnable bias. Default `true`.
 - `init`: Weights' initializer. Default `glorot_uniform`.
-- `init_state`: Initial value of the hidden state of the GRU layer. Default `zeros32`.
 - `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
 - `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
                      If `add_self_loops=true` the new weights will be set to 1.
                      This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
                      Default `false`.
+
+# Forward
+
+    tgcn(g::GNNGraph, x, [h])
+
+- `x`: The input to the TGCN. It should be a matrix of size `in x timesteps` or an array of size `in x timesteps x num_nodes`.
+- `h`: The initial hidden state of the GRU cell. If given, it is a vector of size `out` or a matrix of size `out x num_nodes`.
+       If not provided, it is assumed to be a vector of zeros.
+
 # Examples
 
 ```jldoctest
@@ -78,30 +82,43 @@ Recur(
 ) # Total: 8 trainable arrays, 264 parameters,
   # plus 1 non-trainable, 6 parameters, summarysize 1.492 KiB.
 
-julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5);
+julia> g = rand_graph(5, 10);
+
+julia> x = rand(Float32, 2, 5);
 
 julia> y = tgcn(g, x);
 
 julia> size(y)
 (6, 5)
 
-julia> Flux.reset!(tgcn);
-
-julia> tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)) |> size # batch size of 20
+julia> tgcn(g, rand(Float32, 2, 5, 20)) |> size # batch size of 20
 (6, 5, 20)
 ```
-
-!!! warning "Batch size changes"
-    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior.
 """
-TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...))
+struct TGCN
+    tgcn::TGCNCell
+end
+
+Flux.@layer TGCN
+
+TGCN(ch::Pair{Int, Int}; kws...) = TGCN(TGCNCell(ch; kws...))
 
-Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0)
+function (tgcn::TGCN)(g::GNNGraph, x::AbstractArray, h)
+    for i in 1:size(x, 2)
+        _, h = tgcn.tgcn(g, x[:, i], h)  # do not overwrite the input sequence `x`
+    end
+    return h
+end
+
 
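The docstring above promises a zero hidden state when `h` is omitted, but only the three-argument method is defined in this commit. A possible companion method (a sketch, not part of the commit) that would also let the two-argument doctest calls resolve:

```julia
# Sketch only: default the GRU hidden state to zeros, as described in the docstring.
(tgcn::TGCN)(g::GNNGraph, x::AbstractArray) = tgcn(g, x, Flux.zeros32(tgcn.tgcn.out))
```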
-# make TGCN compatible with GNNChain
-(l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g)
+# TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...))
+
+# Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0)
+
+# # make TGCN compatible with GNNChain
+# (l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g)
 
 
 """
@@ -149,7 +166,7 @@ julia> size(y)
     Failing to call `reset!` when the input batch size changes can lead to unexpected behavior.
 """
 struct A3TGCN <: GNNLayer
-    tgcn::Flux.Recur{TGCNCell}
+    tgcn::TGCN
     dense1::Dense
     dense2::Dense
     in::Int
@@ -272,12 +289,12 @@ julia> size(z)
 (5, 5, 30)
 ```
 """
-GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
-Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
+# GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
+# Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
 
-(l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
+# (l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
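With the `Flux.Recur`-based constructor commented out, this file no longer defines a user-facing `GConvGRU` wrapper. A sketch of an explicit container in the style of the new `TGCN` above (not part of this commit; it assumes `GConvGRUCell` keeps its `state0` field and the `(h, g, x) -> (h, y)` calling convention used by the old `Recur` adapter):

```julia
# Sketch only: explicit single-step wrapper replacing Flux.Recur for GConvGRUCell.
struct GConvGRU
    cell::GConvGRUCell
end

Flux.@layer GConvGRU

GConvGRU(ch, k, n; kws...) = GConvGRU(GConvGRUCell(ch, k, n; kws...))

function (l::GConvGRU)(g::GNNGraph, x)
    h, y = l.cell(l.cell.state0, g, x)  # start from the cell's stored initial state
    return y
end
```

The same pattern would carry over to `GConvLSTMCell` and `DCGRUCell` below.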
 
 struct GConvLSTMCell <: GNNLayer
     conv_x_i::ChebConv
@@ -394,12 +411,12 @@ julia> size(z)
 (5, 5, 30)
 ```
 """
-GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
-Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
+# GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
+# Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
 
-(l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)
+# (l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)
 
 struct DCGRUCell
     in::Int
@@ -477,12 +494,12 @@ julia> size(z)
 (5, 5, 30)
 ```
 """
-DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
-Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
+# DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
+# Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
 
-(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
+# (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
 
 """
     EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
@@ -539,8 +556,6 @@ struct EvolveGCNO
     Bc
 end
 
-Flux.@functor EvolveGCNO
-
 function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
     in, out = ch
     W = init(out, in)
@@ -580,51 +595,3 @@
 function Base.show(io::IO, egcno::EvolveGCNO)
     print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
 end
-
-function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::ChebConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GATConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GATv2Conv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::CGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::SGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::TransformerConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GCNConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::ResGatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::SAGEConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
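Every method deleted above does the same thing: broadcast a static layer over the snapshots of a `TemporalSnapshotsGNNGraph`. If that behaviour still needs a home in this file, a single generic fallback could replace the twelve per-layer definitions; a sketch only, assuming all of the listed layers subtype `GNNLayer`:

```julia
# Sketch only: apply any static GNN layer snapshot by snapshot.
# `x` is a vector holding one feature array per snapshot.
function (l::GNNLayer)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end
```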