Skip to content

Commit 23c429e

Browse files
Add Flux.Recur and TGCN (#319)
Add `TGCN` and `Flux.Recur` * Add `TGCN` and `Flux.Recur` * Add docstring `TGCN` * Add comment * Specify type * Add test * Fix indentation * Add spaces Co-authored-by: Carlo Lucibello <[email protected]> * Improved docstring Co-authored-by: Carlo Lucibello <[email protected]> * Fix docstring Co-authored-by: Carlo Lucibello <[email protected]> * Add `_applylayer` --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 1a4c62b commit 23c429e

File tree

3 files changed

+79
-20
lines changed

3 files changed

+79
-20
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ export
7171
HeteroGraphConv,
7272

7373
# layers/temporalconv
74-
TGCNCell,
74+
TGCN,
7575

7676
# layers/pool
7777
GlobalPool,

src/layers/temporalconv.jl

Lines changed: 69 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,15 @@
1-
"""
2-
TGCNCell(in => out; [bias, init, add_self_loops, use_edge_weight])
3-
4-
Temporal Graph Convolutional Network (T-GCN) cell from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).
5-
6-
Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
7-
8-
# Arguments
1+
# Adapting Flux.Recur to work with GNNGraphs
2+
function (m::Flux.Recur)(g::GNNGraph, x)
3+
m.state, y = m.cell(m.state, g, x)
4+
return y
5+
end
6+
7+
function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T
8+
h = [m(g, x_t) for x_t in Flux.eachlastdim(x)]
9+
sze = size(h[1])
10+
reshape(reduce(hcat, h), sze[1], sze[2], length(h))
11+
end
912

10-
- `in`: Number of input features.
11-
- `out`: Number of output features.
12-
- `bias`: Add learnable bias. Default `true`.
13-
- `init`: Weights' initializer. Default `glorot_uniform`.
14-
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
15-
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
16-
If `add_self_loops=true` the new weights will be set to 1.
17-
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
18-
Default `false`.
19-
"""
2013
struct TGCNCell <: GNNLayer
2114
conv::GCNConv
2215
gru::Flux.GRUv3Cell
@@ -50,3 +43,61 @@ end
5043
function Base.show(io::IO, tgcn::TGCNCell)
5144
print(io, "TGCNCell($(tgcn.in) => $(tgcn.out))")
5245
end
46+
47+
"""
48+
TGCN(in => out; [bias, init, add_self_loops, use_edge_weight])
49+
50+
Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).
51+
52+
Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
53+
54+
# Arguments
55+
56+
- `in`: Number of input features.
57+
- `out`: Number of output features.
58+
- `bias`: Add learnable bias. Default `true`.
59+
- `init`: Weights' initializer. Default `glorot_uniform`.
60+
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
61+
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
62+
If `add_self_loops=true` the new weights will be set to 1.
63+
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
64+
Default `false`.
65+
# Examples
66+
67+
```jldoctest
68+
julia> tgcn = TGCN(2 => 6)
69+
Recur(
70+
TGCNCell(
71+
GCNConv(2 => 6, σ), # 18 parameters
72+
GRUv3Cell(6 => 6), # 240 parameters
73+
Float32[0.0; 0.0; … ; 0.0; 0.0;;], # 6 parameters (all zero)
74+
2,
75+
6,
76+
),
77+
) # Total: 8 trainable arrays, 264 parameters,
78+
# plus 1 non-trainable, 6 parameters, summarysize 1.492 KiB.
79+
80+
julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5);
81+
82+
julia> y = tgcn(g, x);
83+
84+
julia> size(y)
85+
(6, 5)
86+
87+
julia> Flux.reset!(tgcn);
88+
89+
julia> tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)) |> size # batch size of 20
90+
(6, 5, 20)
91+
```
92+
93+
!!! warning "Batch size changes"
94+
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior.
95+
"""
96+
TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...))
97+
98+
Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0)
99+
100+
# make TGCN compatible with GNNChain
101+
(l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
102+
_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
103+
_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g)

test/layers/temporalconv.jl

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,17 @@ g1 = GNNGraph(rand_graph(N,8),
88
graph_type = :sparse)
99

1010
@testset "TGCNCell" begin
11-
tgcn = TGCNCell(in_channel => out_channel)
11+
tgcn = GraphNeuralNetworks.TGCNCell(in_channel => out_channel)
1212
h, x̃ = tgcn(tgcn.state0, g1, g1.ndata.x)
1313
@test size(h) == (out_channel, N)
1414
@test size(x̃) == (out_channel, N)
1515
@test h ==
16+
end
17+
18+
@testset "TGCN" begin
19+
tgcn = TGCN(in_channel => out_channel)
20+
@test size(Flux.gradient(x -> sum(tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
21+
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
22+
@test size(model(g1, g1.ndata.x)) == (1, N)
23+
@test model(g1) isa GNNGraph
1624
end

0 commit comments

Comments (0)