Skip to content

Commit d762110

Browse files
TGCCN
1 parent 23d9e45 commit d762110

File tree

3 files changed

+88
-172
lines changed

3 files changed

+88
-172
lines changed

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ export GNNRecurrence,
5555
GConvLSTM, GConvLSTMCell,
5656
DCGRU, DCGRUCell,
5757
EvolveGCNO, EvolveGCNOCell,
58-
TGCN,
59-
A3TGCN
58+
TGCN, TGCNCell
6059

6160
include("layers/pool.jl")
6261
export GlobalPool,

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ function (rnn::GNNRecurrence)(g, x)
130130
return rnn(g, x, initialstates(rnn))
131131
end
132132

133-
function (rnn::GNNRecurrence)(g, x, state) where {T}
133+
function (rnn::GNNRecurrence)(g, x, state)
134134
return scan(rnn.cell, g, x, state)
135135
end
136136

@@ -750,7 +750,60 @@ julia> size(y[end]) # (d_out, num_nodes[end])
750750
EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))
751751

752752

753+
"""
754+
TGCNCell(in => out; kws...)
755+
756+
Recurrent graph convolutional cell from the paper
757+
[T-GCN: A Temporal Graph Convolutional
758+
Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320).
759+
760+
Uses two stacked [`GCNConv`](@ref) layers to model spatial dependencies,
761+
and a GRU mechanism to model temporal dependencies.
762+
763+
`in` and `out` are the number of input and output node features, respectively.
764+
The keyword arguments are passed to the [`GCNConv`](@ref) constructor.
765+
766+
# Forward
767+
768+
cell(g::GNNGraph, x, [state])
769+
770+
- `g`: The input graph.
771+
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
772+
- `state`: The current state of the cell.
773+
If not provided, it is generated by calling `Flux.initialstates(cell)`.
774+
The state is a matrix of size `out x num_nodes`.
775+
776+
Returns the updated node features and the updated state.
777+
778+
# Examples
779+
780+
```jldoctest
781+
julia> using GraphNeuralNetworks, Flux
782+
783+
julia> num_nodes, num_edges = 5, 10;
784+
785+
julia> d_in, d_out = 2, 3;
786+
787+
julia> timesteps = 5;
788+
789+
julia> g = rand_graph(num_nodes, num_edges);
753790
791+
julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];
792+
793+
julia> cell = DCGRUCell(d_in => d_out, 2);
794+
795+
julia> state = Flux.initialstates(cell);
796+
797+
julia> y = state;
798+
799+
julia> for xt in x
800+
y, state = cell(g, xt, state)
801+
end
802+
803+
julia> size(y) # (d_out, num_nodes)
804+
(3, 5)
805+
```
806+
"""
754807
@concrete struct TGCNCell <: GNNLayer
755808
in::Int
756809
out::Int
@@ -795,4 +848,36 @@ end
795848

796849
function Base.show(io::IO, cell::TGCNCell)
797850
print(io, "TGCNCell($(cell.in) => $(cell.out))")
798-
end
851+
end
852+
853+
"""
854+
TGCN(args...; kws...)
855+
856+
Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell.
857+
858+
The arguments are passed to the [`TGCNCell`](@ref) constructor.
859+
See [`GNNRecurrence`](@ref) for more details.
860+
861+
# Examples
862+
863+
```jldoctest
864+
julia> num_nodes, num_edges = 5, 10;
865+
866+
julia> d_in, d_out = 2, 3;
867+
868+
julia> timesteps = 5;
869+
870+
julia> g = rand_graph(num_nodes, num_edges);
871+
872+
julia> x = rand(Float32, d_in, timesteps, num_nodes);
873+
874+
julia> layer = TGCN(d_in => d_out)
875+
876+
julia> y = layer(g, x);
877+
878+
julia> size(y) # (d_out, timesteps, num_nodes)
879+
(3, 5, 5)
880+
```
881+
"""
882+
TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...))
883+

GraphNeuralNetworks/src/layers/temporalconv_old.jl

Lines changed: 0 additions & 168 deletions
This file was deleted.

0 commit comments

Comments
 (0)