
Commit 2b2a8be

Merge pull request #36 from CarloLucibello/cl/ecconv
add NNConv and FiniteDifferences testing
2 parents: d678551 + 3004d93

File tree

11 files changed: +365 -278 lines


Project.toml

Lines changed: 6 additions & 3 deletions
@@ -1,7 +1,7 @@
 name = "GraphNeuralNetworks"
 uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 authors = ["Carlo Lucibello and contributors"]
-version = "0.1.0"
+version = "0.1.1"

 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -25,16 +25,19 @@ ChainRulesCore = "1"
 DataStructures = "0.18"
 Flux = "0.12"
 KrylovKit = "0.5"
-LearnBase = "0.5"
+LearnBase = "0.4, 0.5"
 LightGraphs = "1.3"
 MacroTools = "0.5"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
 julia = "1.6"

 [extras]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "Zygote"]
+test = ["Test", "Adapt", "Zygote", "FiniteDifferences", "ChainRulesTestUtils"]

src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 3 deletions
@@ -38,13 +38,14 @@ export
     GNNChain,

     # layers/conv
-    GCNConv,
     ChebConv,
-    GraphConv,
+    EdgeConv,
     GATConv,
     GatedGraphConv,
-    EdgeConv,
+    GCNConv,
     GINConv,
+    GraphConv,
+    NNConv,

     # layers/pool
     GlobalPool,

src/gnngraph.jl

Lines changed: 4 additions & 0 deletions
@@ -151,6 +151,10 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
 function GNNGraph(g::AbstractGraph; kws...)
     s = LightGraphs.src.(LightGraphs.edges(g))
     t = LightGraphs.dst.(LightGraphs.edges(g))
+    if !LightGraphs.is_directed(g)
+        # add reverse edges since GNNGraph are directed
+        s, t = [s; t], [t; s]
+    end
     GNNGraph((s, t); num_nodes = LightGraphs.nv(g), kws...)
 end
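In practice, an undirected LightGraphs graph is now stored with both edge directions. A small sketch of the resulting behaviour, using only constructors and accessors that appear in this diff and in the tests below:

using LightGraphs, GraphNeuralNetworks

lg = path_graph(3)            # undirected edges 1-2 and 2-3
g  = GNNGraph(lg)             # every undirected edge becomes two directed edges

g.num_edges == 2 * ne(lg)     # true: 4 directed edges for 2 undirected ones
has_edge(g, 2, 1)             # true: the reverse direction is present as well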

src/layers/basic.jl

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@ See also [`GNNChain`](@ref).
 """
 abstract type GNNLayer end

-#TODO extend to store also edge and global features
+# Forward pass with graph-only input.
+# To be specialized by layers also needing edge features as input (e.g. NNConv).
 (l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g)))

 """

src/layers/conv.jl

Lines changed: 88 additions & 13 deletions
@@ -32,7 +32,7 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in)
-    b = Flux.create_bias(W, bias, out)
+    b = bias ? Flux.create_bias(W, true, out) : false
     GCNConv(W, b, σ)
 end

@@ -105,7 +105,7 @@ function ChebConv(ch::Pair{Int,Int}, k::Int;
                   init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in, k)
-    b = Flux.create_bias(W, bias, out)
+    b = bias ? Flux.create_bias(W, true, out) : false
     ChebConv(W, b, k)
 end

@@ -172,7 +172,7 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
     in, out = ch
     W1 = init(out, in)
     W2 = init(out, in)
-    b = Flux.create_bias(W1, bias, out)
+    b = bias ? Flux.create_bias(W1, true, out) : false
     GraphConv(W1, W2, b, σ, aggr)
 end

@@ -196,7 +196,7 @@ end


 @doc raw"""
-    GATConv(in => out, , σ=identity;
+    GATConv(in => out, σ=identity;
             heads=1,
             concat=true,
             init=glorot_uniform
@@ -224,7 +224,7 @@ with ``z_i`` a normalization factor.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
 """
-struct GATConv{T, A<:AbstractMatrix{T}, B} <: GNNLayer
+struct GATConv{T, A<:AbstractMatrix, B} <: GNNLayer
     weight::A
     bias::B
     a::A
@@ -239,12 +239,13 @@ end
 Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)

 function GATConv(ch::Pair{Int,Int}, σ=identity;
-                 heads::Int=1, concat::Bool=true, negative_slope=0.2f0,
+                 heads::Int=1, concat::Bool=true, negative_slope=0.2,
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
-    b = Flux.create_bias(W, bias, out*heads)
+    b = bias ? Flux.create_bias(W, true, out*heads) : false
     a = init(2*out, heads)
+    negative_slope = convert(eltype(W), negative_slope)
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end

@@ -356,20 +357,20 @@ end


 @doc raw"""
-    EdgeConv(f; aggr=max)
+    EdgeConv(nn; aggr=max)

 Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829).

 Performs the operation
 ```math
-\mathbf{x}_i' = \square_{j \in N(i)} f(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
+\mathbf{x}_i' = \square_{j \in N(i)} nn(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
 ```

-where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+where `nn` generally denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

 # Arguments

-- `f`: A (possibly learnable) function acting on edge features.
+- `nn`: A (possibly learnable) function acting on edge features.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 """
 struct EdgeConv <: GNNLayer

@@ -405,9 +406,9 @@ Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural

 ```math
-\mathbf{x}_i' = f\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
+\mathbf{x}_i' = f_\Theta\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
 ```
-where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

 # Arguments

@@ -434,3 +435,77 @@ function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     X, _ = propagate(l, g, +, X)
     X
 end
+
+
+@doc raw"""
+    NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+
+The continuous kernel-based convolutional operator from the
+[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
+This convolution is also known as the edge-conditioned convolution from the
+[Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper.
+
+Performs the operation
+
+```math
+\mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j
+```
+
+where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
+Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`,
+the function `f` will return a batched array of matrices of size `(out, in, num_edges)`.
+For convenience, functions returning a single `(out*in, num_edges)` matrix are also allowed.
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `f`: A (possibly learnable) function acting on edge features.
+- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
+- `σ`: Activation function.
+- `bias`: Add learnable bias.
+- `init`: Weights' initializer.
+"""
+struct NNConv <: GNNLayer
+    weight
+    bias
+    nn
+    σ
+    aggr
+end
+
+@functor NNConv
+
+function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+    in, out = ch
+    W = init(out, in)
+    b = bias ? Flux.create_bias(W, true, out) : false
+    return NNConv(W, b, nn, σ, aggr)
+end
+
+function compute_message(l::NNConv, x_i, x_j, e_ij)
+    nin, nedges = size(x_i)
+    W = reshape(l.nn(e_ij), (:, nin, nedges))
+    x_j = reshape(x_j, (nin, 1, nedges)) # needed by batched_mul
+    m = NNlib.batched_mul(W, x_j)
+    return reshape(m, :, nedges)
+end
+
+function update_node(l::NNConv, m, x)
+    l.σ.(l.weight*x .+ m .+ l.bias)
+end
+
+function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
+    check_num_nodes(g, x)
+    x, _ = propagate(l, g, l.aggr, x, e)
+    return x
+end
+
+(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))
+
+function Base.show(io::IO, l::NNConv)
+    out, in = size(l.weight)
+    print(io, "NNConv( $in => $out")
+    print(io, ", aggr=", l.aggr)
+    print(io, ")")
+end
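A brief usage sketch of the new layer, assuming the constructor and forward signatures added above. The edge network here is a plain Flux Dense layer producing the flattened (out*in, num_edges) kernels that compute_message reshapes for batched_mul; the feature sizes are arbitrary illustrative choices:

using Flux, GraphNeuralNetworks

nin, nout, ein = 4, 5, 3
s, t = [1, 2, 3], [2, 3, 1]
g = GNNGraph((s, t), ndata = rand(Float32, nin, 3),
                     edata = rand(Float32, ein, 3))

# Maps each edge feature vector to a flattened nout*nin kernel,
# reshaped by compute_message to (nout, nin, num_edges).
edge_nn = Dense(ein, nout * nin)
l = NNConv(nin => nout, edge_nn, relu; aggr = +)

x_new = l(g, node_features(g), edge_features(g))   # (nout, num_nodes) matrix
g_new = l(g)                                        # same result wrapped in a new GNNGraph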

test/cuda/layers/conv.jl

Lines changed: 0 additions & 49 deletions
This file was deleted.

test/cuda/test_utils.jl

Lines changed: 0 additions & 32 deletions
This file was deleted.

test/gnngraph.jl

Lines changed: 15 additions & 2 deletions
@@ -84,6 +84,19 @@
     end
 end

+@testset "LightGraphs constructor" begin
+    lg = random_regular_graph(10, 4)
+    @test !LightGraphs.is_directed(lg)
+    g = GNNGraph(lg)
+    @test g.num_edges == 2*ne(lg) # g in undirected
+    @test LightGraphs.is_directed(g)
+    for e in LightGraphs.edges(lg)
+        i, j = src(e), dst(e)
+        @test has_edge(g, i, j)
+        @test has_edge(g, j, i)
+    end
+end
+
 @testset "add self-loops" begin
     A = [1 1 0 0
          0 0 1 0
@@ -174,9 +187,9 @@
 @testset "LearnBase and DataLoader compat" begin
     n, m, num_graphs = 10, 30, 50
     X = rand(10, n)
-    E = rand(10, m)
+    E = rand(10, 2m)
     U = rand(10, 1)
-    g = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(10, n), edata=rand(10, m), gdata=rand(10, 1))
+    g = Flux.batch([GNNGraph(erdos_renyi(n, m), ndata=X, edata=E, gdata=U)
                     for _ in 1:num_graphs])

     @test LearnBase.getobs(g, 3) == getgraph(g, 3)[1]
