Commit 5587f7b

Add TGCNCell layer (#314)
* Add `TGCNCell` layer
* Export
* Add show function
* Add docstring
* Use `Flux.GRUv3Cell`
* Improve
1 parent 9fccacd commit 5587f7b

4 files changed: 73 additions, 0 deletions


src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,9 @@ export
   # layers/heteroconv
   HeteroGraphConv,
 
+  # layers/temporalconv
+  TGCNCell,
+
   # layers/pool
   GlobalPool,
   GlobalAttentionPool,
@@ -84,6 +87,7 @@ include("utils.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/heteroconv.jl")
+include("layers/temporalconv.jl")
 include("layers/pool.jl")
 include("msgpass.jl")
 include("mldatasets.jl")

src/layers/temporalconv.jl

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
"""
    TGCNCell(in => out; [bias, init, init_state, add_self_loops, use_edge_weight])

Temporal Graph Convolutional Network (T-GCN) cell from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).

Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initializer of the hidden state. Default `zeros32`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
                     If `add_self_loops=true` the new weights will be set to 1.
                     This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
                     Default `true`.
"""
struct TGCNCell <: GNNLayer
    conv::GCNConv
    gru::Flux.GRUv3Cell
    state0
    in::Int
    out::Int
end

Flux.@functor TGCNCell

function TGCNCell(ch::Pair{Int, Int};
                  bias::Bool = true,
                  init = Flux.glorot_uniform,
                  init_state = Flux.zeros32,
                  add_self_loops = false,
                  use_edge_weight = true)
    in, out = ch
    conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops,
                   use_edge_weight)
    gru = Flux.GRUv3Cell(out, out)
    state0 = init_state(out, 1)
    return TGCNCell(conv, gru, state0, in, out)
end

function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray)
    x̃ = tgcn.conv(g, x)
    h, x̃ = tgcn.gru(h, x̃)
    return h, x̃
end

function Base.show(io::IO, tgcn::TGCNCell)
    print(io, "TGCNCell($(tgcn.in) => $(tgcn.out))")
end
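
As a quick illustration of the new layer (not part of the diff), here is a minimal usage sketch: it constructs a `TGCNCell`, builds a small random graph, and runs one forward step. The graph size and feature dimensions are arbitrary choices mirroring the test added below.

# Minimal usage sketch, assuming GraphNeuralNetworks is loaded and exports TGCNCell.
using GraphNeuralNetworks

g = rand_graph(4, 8)             # 4 nodes, 8 random edges
x = rand(Float32, 3, 4)          # 3 input features per node
tgcn = TGCNCell(3 => 5)          # prints as "TGCNCell(3 => 5)" via Base.show

# One step: GCNConv aggregates over the graph, then the GRUv3 cell updates the hidden state.
h, x̃ = tgcn(tgcn.state0, g, x)   # both h and x̃ have size (5, 4)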

test/layers/temporalconv.jl

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
in_channel = 3
out_channel = 5
N = 4
T = Float32

g1 = GNNGraph(rand_graph(N, 8),
              ndata = rand(T, in_channel, N),
              graph_type = :sparse)

@testset "TGCNCell" begin
    tgcn = TGCNCell(in_channel => out_channel)
    h, x̃ = tgcn(tgcn.state0, g1, g1.ndata.x)
    @test size(h) == (out_channel, N)
    @test size(x̃) == (out_channel, N)
    @test h == x̃
end
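
For context (also not part of the commit), the sketch below unrolls the cell over a sequence of feature snapshots on a fixed graph, carrying the hidden state across time steps. The `unroll` helper, the 10-step sequence, and the sizes are assumptions made for illustration only.

# Sketch of a hypothetical unroll helper: thread the hidden state through a sequence of snapshots.
using GraphNeuralNetworks

function unroll(tgcn::TGCNCell, g::GNNGraph, xs)
    h = tgcn.state0
    outputs = map(xs) do x
        h, x̃ = tgcn(h, g, x)   # the updated hidden state is reused at the next step
        x̃                      # per-step node embeddings of size (out, num_nodes)
    end
    return h, outputs
end

tgcn = TGCNCell(3 => 5)
g = rand_graph(4, 8)
xs = [rand(Float32, 3, 4) for _ in 1:10]   # 10 time steps of node features for the same graph
h, outputs = unroll(tgcn, g, xs)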

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ tests = [
     "layers/basic",
     "layers/conv",
     "layers/heteroconv",
+    "layers/temporalconv",
     "layers/pool",
     "mldatasets",
     "examples/node_classification_cora",
