Skip to content

Commit 43aedc2

Browse files
GConvGRU
1 parent a6700c3 commit 43aedc2

File tree

6 files changed

+752
-550
lines changed

6 files changed

+752
-550
lines changed

GraphNeuralNetworks/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.0.0-DEV"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
89
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
910
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
1011
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
@@ -18,7 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1819

1920
[compat]
2021
ChainRulesCore = "1"
21-
Flux = "0.15"
22+
ConcreteStructs = "0.2.3"
23+
Flux = "0.16.0"
2224
GNNGraphs = "1.4"
2325
GNNlib = "1"
2426
LinearAlgebra = "1"

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ module GraphNeuralNetworks
33
using Statistics: mean
44
using LinearAlgebra, Random
55
using Flux
6-
using Flux: glorot_uniform, leakyrelu, GRUCell, batch
6+
using Flux: glorot_uniform, leakyrelu, GRUCell, batch, initialstates
77
using MacroTools: @forward
88
using NNlib
99
using ChainRulesCore
1010
using Reexport: @reexport
1111
using MLUtils: zeros_like
12+
using ConcreteStructs: @concrete
1213

1314
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
1415
check_num_nodes, check_num_edges,
@@ -49,10 +50,10 @@ include("layers/heteroconv.jl")
4950
export HeteroGraphConv
5051

5152
include("layers/temporalconv.jl")
52-
export TGCN,
53+
export GConvGRU, GConvGRUCell,
54+
TGCN,
5355
A3TGCN,
5456
GConvLSTM,
55-
GConvGRU,
5657
DCGRU,
5758
EvolveGCNO
5859

GraphNeuralNetworks/src/layers/conv.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# The implementations of the forward pass of the graph convolutional layers are in the `GNNlib` module,
2+
# in the src/layers/conv.jl file. The `GNNlib` module is re-exported in the GraphNeuralNetworks module.
3+
# This annoying for the readability of the code, as the user has to look at two different files to understand
4+
# the implementation of a single layer,
5+
# but it is done for GraphNeuralNetworks.jl and GNNLux.jl to be able to share the same code.
6+
17
@doc raw"""
28
GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight])
39

GraphNeuralNetworks/src/layers/pool.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,6 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
155155
return Set2Set(lstm, n_iters)
156156
end
157157

158-
function initialstates(cell::LSTMCell)
159-
h = zeros_like(cell.Wh, size(cell.Wh, 2))
160-
c = zeros_like(cell.Wh, size(cell.Wh, 2))
161-
return h, c
162-
end
163-
164158
function (l::Set2Set)(g, x)
165159
return GNNlib.set2set_pool(l, g, x)
166160
end

0 commit comments

Comments
 (0)