Commit 27d13c8
rewrite recurrent temporal layers for Flux v0.16 (#560)
* GConvGRU
* GConvLSTM
* GNNRecurrence
* EvolveGCNOCell
* cleanup
* EvolveGCNO
* TGCNCell
* TGCCN
* tests
* fix gatedgraphconv
* fix set2set
1 parent bbff8a9 commit 27d13c8

10 files changed: +1019 −618 lines changed

GNNLux/src/layers/conv.jl

Lines changed: 5 additions & 1 deletion
@@ -1261,7 +1261,11 @@ LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l
 
 function (l::GatedGraphConv)(g, x, ps, st)
     gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
-    fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style
+    # make the forward compatible with Flux.GRUCell style
+    function fgru(x, h)
+        y, (h,) = gru((x, (h,)))
+        return y, h
+    end
     m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
     return GNNlib.gated_graph_conv(m, g, x), st
end
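
The deleted one-liner kept only the output of the Lux cell; the new fgru returns the Flux-style (output, state) pair that the shared GNNlib forward now expects. A minimal, self-contained sketch of the same adapter pattern, assuming Lux's GRUCell and StatefulLuxLayer; the 4 => 4 dimensions and names are illustrative, not from the diff:

using Lux, Random

rng = Random.default_rng()
gru = GRUCell(4 => 4)
ps, st = Lux.setup(rng, gru)
sgru = StatefulLuxLayer{true}(gru, ps, st)

# Lux convention: (x, (h,)) in, (y, (h,)) out.
# Flux v0.16 convention: (x, h) in, (y, h) out.
function fgru(x, h)
    y, (hnew,) = sgru((x, (h,)))
    return y, hnew
end

x = rand(Float32, 4, 3)   # features × batch
h = zeros(Float32, 4, 3)
y, h = fgru(x, h)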

GNNlib/src/layers/conv.jl

Lines changed: 1 addition & 2 deletions
@@ -227,8 +227,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
     for i in 1:(l.num_layers)
         m = view(l.weight, :, :, i) * h
         m = propagate(copy_xj, g, l.aggr; xj = m)
-        # in gru forward, hidden state is first argument, input is second
-        h = l.gru(m, h)
+        _, h = l.gru(m, h)
     end
     return h
end
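
The substantive change here: in Flux v0.16 recurrent cells return an (output, new_state) tuple rather than the bare state, so the loop destructures and keeps only the state. For a GRU the output coincides with the new hidden state, hence the `_`. A sketch of the convention, with illustrative dimensions:

using Flux

cell = GRUCell(4 => 4)
x = rand(Float32, 4, 3)   # input: features × batch
h = zeros(Float32, 4, 3)  # hidden state
y, h = cell(x, h)         # v0.16: (output, new_state); for GRU, y == h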

GNNlib/src/layers/pool.jl

Lines changed: 2 additions & 2 deletions
@@ -31,9 +31,9 @@ function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
     qstar = zeros_like(x, (2*n_in, g.num_graphs))
     h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
     c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
+    state = (h, c)
     for t in 1:l.num_iters
-        h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs]
-        q = h
+        q, state = l.lstm(qstar, state) # [n_in, n_graphs]
         qn = broadcast_nodes(g, q) # [n_in, n_nodes]
         α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
         r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
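
Same convention for LSTMCell: the (h, c) pair travels as one state tuple and the cell returns (output, state), which is why q and state replace the h, c destructuring. An illustrative sketch:

using Flux

lstm = LSTMCell(8 => 4)
x = rand(Float32, 8, 5)                               # features × batch
state = (zeros(Float32, 4, 5), zeros(Float32, 4, 5))  # (h, c)
q, state = lstm(x, state)   # returns (output, (h, c))
q, state = lstm(x, state)   # state threads through repeated calls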

GraphNeuralNetworks/Project.toml

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,7 @@ version = "1.0.0-DEV"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
@@ -18,7 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
-Flux = "0.15"
+ConcreteStructs = "0.2.3"
+Flux = "0.16.0"
 GNNGraphs = "1.4"
 GNNlib = "1"
 LinearAlgebra = "1"

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 8 additions & 7 deletions
@@ -3,12 +3,13 @@ module GraphNeuralNetworks
 using Statistics: mean
 using LinearAlgebra, Random
 using Flux
-using Flux: glorot_uniform, leakyrelu, GRUCell, batch
+using Flux: glorot_uniform, leakyrelu, GRUCell, batch, initialstates
 using MacroTools: @forward
 using NNlib
 using ChainRulesCore
 using Reexport: @reexport
 using MLUtils: zeros_like
+using ConcreteStructs: @concrete
 
 using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
                  check_num_nodes, check_num_edges,
@@ -49,12 +50,12 @@ include("layers/heteroconv.jl")
 export HeteroGraphConv
 
 include("layers/temporalconv.jl")
-export TGCN,
-       A3TGCN,
-       GConvLSTM,
-       GConvGRU,
-       DCGRU,
-       EvolveGCNO
+export GNNRecurrence,
+       GConvGRU, GConvGRUCell,
+       GConvLSTM, GConvLSTMCell,
+       DCGRU, DCGRUCell,
+       EvolveGCNO, EvolveGCNOCell,
+       TGCN, TGCNCell
 
 include("layers/pool.jl")
 export GlobalPool,
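
The reworked export list reflects the new design: each temporal layer comes as a single-step cell (GConvGRUCell, TGCNCell, ...) plus wrappers such as GNNRecurrence that iterate a cell over a temporal sequence (note that A3TGCN is no longer exported). A hypothetical usage sketch; the constructor arguments and input layout are assumed from the pre-existing GConvGRU layer and are not shown in this diff:

using GraphNeuralNetworks

cell = GConvGRUCell(2 => 4, 2)   # assumed: in => out features, Chebyshev order k
layer = GNNRecurrence(cell)      # iterates the cell over the time dimension
g = rand_graph(5, 10)            # 5 nodes, 10 edges
x = rand(Float32, 2, 8, 5)       # assumed layout: [features, timesteps, nodes]
y = layer(g, x)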

GraphNeuralNetworks/src/layers/conv.jl

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# The implementations of the forward pass of the graph convolutional layers are in the `GNNlib` module,
+# in the src/layers/conv.jl file. The `GNNlib` module is re-exported in the GraphNeuralNetworks module.
+# This is annoying for the readability of the code, as the user has to look at two different files
+# to understand the implementation of a single layer,
+# but it is done so that GraphNeuralNetworks.jl and GNNLux.jl can share the same code.
+
 @doc raw"""
     GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight])
 
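A self-contained schematic of the sharing pattern this comment describes, with hypothetical names standing in for the real layers:

# Stand-in for GNNlib: framework-agnostic forward pass.
module FakeGNNlib
    my_conv(l, x) = l.weight * x
end

# Stand-in for the Flux-facing side in GraphNeuralNetworks/src/layers/conv.jl:
# the struct and docstring live here, the math lives in the shared module.
struct MyConv
    weight::Matrix{Float32}
end
(l::MyConv)(x) = FakeGNNlib.my_conv(l, x)

l = MyConv(rand(Float32, 4, 2))
y = l(rand(Float32, 2, 10))   # 4×10 output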
GraphNeuralNetworks/src/layers/pool.jl

Lines changed: 0 additions & 6 deletions
@@ -155,12 +155,6 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
     return Set2Set(lstm, n_iters)
 end
 
-function initialstates(cell::LSTMCell)
-    h = zeros_like(cell.Wh, size(cell.Wh, 2))
-    c = zeros_like(cell.Wh, size(cell.Wh, 2))
-    return h, c
-end
-
 function (l::Set2Set)(g, x)
     return GNNlib.set2set_pool(l, g, x)
end
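
The hand-rolled method is deleted because Flux v0.16 provides initialstates itself, now imported in GraphNeuralNetworks.jl (see the module diff above). A sketch, assuming it returns the zeroed state for the cell:

using Flux

lstm = LSTMCell(8 => 4)
state = Flux.initialstates(lstm)   # for an LSTM cell, presumably the zeroed (h, c) pair
q, state = lstm(rand(Float32, 8, 5), state)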
