
Commit 75fa31f

Add GConvGRU temporal layer (#438)
* Add `GConvGRU`
* export `GConvGRU`
* Add `GConvGRU` test
* Add example with temporal features
* Fix
1 parent 217f767 commit 75fa31f

File tree

- src/GraphNeuralNetworks.jl
- src/layers/temporalconv.jl
- test/layers/temporalconv.jl

3 files changed: +107 −0 lines changed

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions

```diff
@@ -73,6 +73,7 @@ export
     # layers/temporalconv
     TGCN,
     A3TGCN,
+    GConvGRU,

     # layers/pool
     GlobalPool,
```

src/layers/temporalconv.jl

Lines changed: 92 additions & 0 deletions
```diff
@@ -187,6 +187,98 @@ function Base.show(io::IO, a3tgcn::A3TGCN)
     print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))")
 end
 
+struct GConvGRUCell <: GNNLayer
+    conv_x_r::ChebConv
+    conv_h_r::ChebConv
+    conv_x_z::ChebConv
+    conv_h_z::ChebConv
+    conv_x_h::ChebConv
+    conv_h_h::ChebConv
+    k::Int
+    state0
+    in::Int
+    out::Int
+end
+
+Flux.@functor GConvGRUCell
+
+function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int;
+                      bias::Bool = true,
+                      init = Flux.glorot_uniform,
+                      init_state = Flux.zeros32)
+    in, out = ch
+    # reset gate
+    conv_x_r = ChebConv(in => out, k; bias, init)
+    conv_h_r = ChebConv(out => out, k; bias, init)
+    # update gate
+    conv_x_z = ChebConv(in => out, k; bias, init)
+    conv_h_z = ChebConv(out => out, k; bias, init)
+    # new gate
+    conv_x_h = ChebConv(in => out, k; bias, init)
+    conv_h_h = ChebConv(out => out, k; bias, init)
+    state0 = init_state(out, n)
+    return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out)
+end
+
+function (ggru::GConvGRUCell)(h, g::GNNGraph, x)
+    r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h)
+    r = Flux.sigmoid_fast(r)
+    z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h)
+    z = Flux.sigmoid_fast(z)
+    h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* h)
+    h̃ = Flux.tanh_fast(h̃)
+    h = (1 .- z) .* h̃ .+ z .* h
+    return h, h
+end
+
+function Base.show(io::IO, ggru::GConvGRUCell)
+    print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))")
+end
```
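For reference, the cell's forward pass above is the standard GRU update with Chebyshev graph convolutions in place of dense layers. Writing `C_xr`, `C_hr`, … for the corresponding `conv_*` ChebConv layers (notation transcribed from the code, not taken from the paper):

```math
\begin{aligned}
r &= \sigma\big(C_{xr}(x) + C_{hr}(h)\big) \\
z &= \sigma\big(C_{xz}(x) + C_{hz}(h)\big) \\
\tilde h &= \tanh\big(C_{xh}(x) + C_{hh}(r \odot h)\big) \\
h' &= (1 - z) \odot \tilde h + z \odot h
\end{aligned}
```

Here `σ` is the logistic sigmoid, `⊙` the elementwise product, and `h'` serves as both the new hidden state and the cell's output. The rest of the hunk adds the docstring and the `Flux.Recur` integration: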
```diff
+"""
+    GConvGRU(in => out, k, n; [bias, init, init_state])
+
+Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
+
+Performs a ChebConv layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `k`: Chebyshev polynomial order.
+- `n`: Number of nodes in the graph.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `init_state`: Initializer of the hidden state of the GRU layer. Default `zeros32`.
+
+# Examples
+
+```jldoctest
+julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
+
+julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes);
+
+julia> y = ggru(g1, x1);
+
+julia> size(y)
+(5, 5)
+
+julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
+
+julia> z = ggru(g2, x2);
+
+julia> size(z)
+(5, 5, 30)
+```
+"""
+GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
+Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
+
+(l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
+_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
 
 function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
     return l.(tg.snapshots, x)
 end
```
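As a usage sketch (not part of the commit; names are illustrative): the `Flux.Recur` wrapper threads the hidden state across calls, so repeated applications advance the recurrence, and `Flux.reset!` restores `state0`:

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(5, 10)                    # 5 nodes, 10 edges
ggru = GConvGRU(2 => 5, 2, g.num_nodes)  # 2 => 5 features, Chebyshev order k = 2
x = rand(Float32, 2, 5)

y1 = ggru(g, x)    # first step, starting from state0
y2 = ggru(g, x)    # second step, reusing the hidden state updated above
Flux.reset!(ggru)  # restore the initial hidden state
```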

test/layers/temporalconv.jl

Lines changed: 14 additions & 0 deletions
```diff
@@ -34,6 +34,20 @@ end
     @test model(g1) isa GNNGraph
 end
 
+@testset "GConvGRUCell" begin
+    gconvgru = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes)
+    h, _ = gconvgru(gconvgru.state0, g1, g1.ndata.x)
+    @test size(h) == (out_channel, N)
+end
+
+@testset "GConvGRU" begin
+    gconvgru = GConvGRU(in_channel => out_channel, 2, g1.num_nodes)
+    @test size(Flux.gradient(x -> sum(gconvgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
+    model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
+    @test size(model(g1, g1.ndata.x)) == (1, N)
+    @test model(g1) isa GNNGraph
+end
+
 @testset "GINConv" begin
     ginconv = GINConv(Dense(in_channel => out_channel),0.3)
     @test length(ginconv(tg, tg.ndata.x)) == S
```
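For context, a minimal training sketch under assumptions not in the commit (a shared graph across time steps, made-up data, explicit-mode Flux optimisers); it relies on `Flux.Recur` carrying the hidden state through the sequence:

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(5, 10)
model = GNNChain(GConvGRU(2 => 5, 2, g.num_nodes), Dense(5, 1))
opt = Flux.setup(Adam(1f-3), model)

xs = [rand(Float32, 2, 5) for _ in 1:8]  # toy input sequence (features x nodes)
ys = [rand(Float32, 1, 5) for _ in 1:8]  # toy per-step targets

Flux.reset!(model)  # start the sequence from state0
loss, grads = Flux.withgradient(model) do m
    # accumulate the loss over time; the recurrence advances at each call
    sum(Flux.mse(m(g, x), y) for (x, y) in zip(xs, ys))
end
Flux.update!(opt, model, grads[1])
```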
