function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
    # Run `cell` over a temporal sequence `x` of size (features, timesteps, num_nodes)
    # on a static graph `g`, threading `state` through the timesteps.
    # Returns the stacked per-timestep outputs with time on dimension 2.
    y = []
    for xt in eachslice(x, dims = 2)
        yt, state = cell(g, xt, state)
        # push! appends in amortized O(1); the previous `y = vcat(y, [yt])`
        # rebuilt the whole vector each step (quadratic in the number of timesteps).
        push!(y, yt)
    end
    return stack(y, dims = 2)
end
99
function scan(cell, tg::TemporalSnapshotsGNNGraph, x::AbstractVector, state)
    # Run `cell` over a sequence of graph snapshots, pairing snapshot t with
    # node features `x[t]` and threading `state` through the timesteps.
    # Returns the vector of per-timestep outputs.
    # Explicit validation instead of the previously commented-out @assert:
    # a shorter `x` would otherwise silently truncate the sequence.
    length(x) == length(tg.snapshots) ||
        throw(DimensionMismatch("got $(length(x)) feature arrays for $(length(tg.snapshots)) snapshots"))
    y = []
    for (gt, xt) in zip(tg.snapshots, x)
        yt, state = cell(gt, xt, state)
        push!(y, yt)  # amortized O(1), instead of rebuilding with vcat each step
    end
    return y
end
20+
1021
1122"""
1223 GNNRecurrence(cell)
@@ -20,7 +31,7 @@ to process an entire temporal sequence of node features at once.
2031
2132- `g`: The input graph.
2233- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
- `state`: The current state of the cell.
2435 If not provided, it is generated by calling `Flux.initialstates(cell)`.
2536
2637Applies the recurrent cell to each timestep of the input sequence and returns the output as
@@ -61,11 +72,11 @@ Flux.@layer GNNRecurrence
6172
# The default state of a GNNRecurrence is the default state of its wrapped cell.
function Flux.initialstates(rnn::GNNRecurrence)
    return Flux.initialstates(rnn.cell)
end
6374
# Forward without an explicit state: start from the cell's default state.
function (rnn::GNNRecurrence)(g, x)
    state0 = initialstates(rnn)
    return rnn(g, x, state0)
end
6778
# Forward with an explicit state: delegate to `scan`, which dispatches on the
# graph type (static `GNNGraph` vs `TemporalSnapshotsGNNGraph`).
# Note: the stray `where {T}` left over from the previously typed signature is
# removed — no `T` appeared in the argument list, so it was an unused type variable.
function (rnn::GNNRecurrence)(g, x, state)
    return scan(rnn.cell, g, x, state)
end
7182
@@ -97,7 +108,7 @@ followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
97108
98109- `g`: The input graph.
99110- `x`: The node features. It should be a matrix of size `in x num_nodes`.
100- - `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
111+ - `h`: The current hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
101112 If not provided, it is assumed to be a matrix of zeros.
102113
103114Performs one recurrence step and returns a tuple `(h, h)`,
@@ -251,9 +262,9 @@ followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
251262
252263- `g`: The input graph.
253264- `x`: The node features. It should be a matrix of size `in x num_nodes`.
- `state`: The current state of the LSTM cell.
  If given, it is a tuple `(h, c)` where both `h` and `c` are arrays of size `out x num_nodes`.
  If not provided, it is assumed to be a tuple of matrices of zeros.
257268
258269Performs one recurrence step and returns a tuple `(output, state)`,
259270where `output` is the updated hidden state `h` of the LSTM cell and `state` is the updated tuple `(h, c)`.
@@ -434,7 +445,7 @@ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependen
434445
435446- `g`: The input graph.
436447- `x`: The node features. It should be a matrix of size `in x num_nodes`.
437- - `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
448+ - `h`: The current state of the GRU cell. It is a matrix of size `out x num_nodes`.
438449 If not provided, it is assumed to be a matrix of zeros.
439450
440451Performs one recurrence step and returns a tuple `(h, h)`,
@@ -547,4 +558,96 @@ julia> size(y) # (d_out, timesteps, num_nodes)
547558"""
# Convenience constructor: a DCGRUCell processed over a whole temporal sequence.
function DCGRU(args...; kws...)
    return GNNRecurrence(DCGRUCell(args...; kws...))
end
549560
"""
    EvolveGCNOCell(in => out; bias = true, init = glorot_uniform)

Evolving Graph Convolutional Network cell of type "-O" from the paper
[EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/abs/1902.10191).

Uses a [`GCNConv`](@ref) layer to model spatial dependencies, and an `LSTMCell` to model temporal dependencies.
Can work with time-varying graphs and node features.

# Arguments

- `in => out`: A pair where `in` is the number of input node features and `out`
  is the number of output node features.
- `bias`: Add learnable bias for the convolution and the lstm cell. Default `true`.
- `init`: Weights' initializer for the convolution. Default `glorot_uniform`.

# Forward

    cell(g::GNNGraph, x, [state]) -> x, state

- `g`: The input graph.
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
- `state`: The current state of the cell.
  A state is a tuple `(weight, lstm)` where `weight` is the convolution's weight and `lstm` is the lstm's state.
  If not provided, it is generated by calling `Flux.initialstates(cell)`.

Returns the updated node features `x` and the updated state.

# Examples

```jldoctest
julia> using GraphNeuralNetworks, Flux

julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;

julia> timesteps = 5;

julia> g = [rand_graph(num_nodes, num_edges) for t in 1:timesteps];

julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];

julia> cell1 = EvolveGCNOCell(d_in => d_out)
EvolveGCNOCell(2 => 3) # 321 parameters

julia> cell2 = EvolveGCNOCell(d_out => d_out)
EvolveGCNOCell(3 => 3) # 696 parameters

julia> state1 = Flux.initialstates(cell1);

julia> state2 = Flux.initialstates(cell2);

julia> outputs = [];

julia> for t in 1:timesteps
           zt, state1 = cell1(g[t], x[t], state1)
           yt, state2 = cell2(g[t], zt, state2)
           outputs = vcat(outputs, [yt])
       end

julia> size(outputs[end]) # (d_out, num_nodes)
(3, 5)
```
"""
struct EvolveGCNOCell{C,L} <: GNNLayer
    in::Int    # number of input node features
    out::Int   # number of output node features
    conv::C    # GCNConv layer whose weight matrix is evolved through time
    lstm::L    # LSTMCell that evolves the flattened convolution weight
end

Flux.@layer :noexpand EvolveGCNOCell
550632
# Build an EvolveGCNOCell: the convolution supplies the initial weight and the
# LSTM evolves its flattened (in*out)-element weight vector across timesteps.
function EvolveGCNOCell((in, out)::Pair{Int,Int}; bias = true, init = glorot_uniform)
    gcnconv = GCNConv(in => out; bias, init)
    weight_lstm = LSTMCell(in * out => in * out; bias)
    return EvolveGCNOCell(in, out, gcnconv, weight_lstm)
end
638+
# The cell state carries the flattened convolution weight together with the
# LSTM's own state.
function Flux.initialstates(cell::EvolveGCNOCell)
    w0 = vec(cell.conv.weight)  # vec(x) ≡ reshape(x, :)
    lstm_state = Flux.initialstates(cell.lstm)
    return (; weight = w0, lstm = lstm_state)
end
644+
# One recurrence step: evolve the flattened convolution weight with the LSTM,
# then apply the convolution with the evolved weight to the node features.
function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state)
    new_weight, new_lstm = cell.lstm(state.weight, state.lstm)
    W = reshape(new_weight, (cell.out, cell.in))
    y = cell.conv(g, x, conv_weight = W)
    return y, (; weight = new_weight, lstm = new_lstm)
end
650+
# Compact display, e.g. `EvolveGCNOCell(2 => 3)`.
function Base.show(io::IO, cell::EvolveGCNOCell)
    print(io, "EvolveGCNOCell(", cell.in, " => ", cell.out, ")")
end
0 commit comments