Commit b4b9020

add NNConv tests
1 parent 761113a commit b4b9020

6 files changed: 113 additions & 65 deletions

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"]
 version = "0.1.0"

 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 3 deletions
@@ -38,13 +38,14 @@ export
   GNNChain,

   # layers/conv
-  GCNConv,
   ChebConv,
-  GraphConv,
+  EdgeConv,
   GATConv,
   GatedGraphConv,
-  EdgeConv,
+  GCNConv,
   GINConv,
+  GraphConv,
+  NNConv,

   # layers/pool
   GlobalPool,

src/layers/basic.jl

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@ See also [`GNNChain`](@ref).
 """
 abstract type GNNLayer end

-#TODO extend to store also edge and global features
+# Forward pass with graph-only input.
+# To be specialized by layers also needing edge features as input (e.g. NNConv).
 (l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g)))

 """

src/layers/conv.jl

Lines changed: 20 additions & 11 deletions
@@ -196,7 +196,7 @@ end


 @doc raw"""
-    GATConv(in => out, , σ=identity;
+    GATConv(in => out, σ=identity;
             heads=1,
             concat=true,
             init=glorot_uniform
@@ -224,7 +224,7 @@ with ``z_i`` a normalization factor.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
 """
-struct GATConv{T, A<:AbstractMatrix{T}, B} <: GNNLayer
+struct GATConv{T, A<:AbstractMatrix, B} <: GNNLayer
     weight::A
     bias::B
     a::A
@@ -239,12 +239,13 @@ end
 Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)

 function GATConv(ch::Pair{Int,Int}, σ=identity;
-                 heads::Int=1, concat::Bool=true, negative_slope=0.2f0,
+                 heads::Int=1, concat::Bool=true, negative_slope=0.2,
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
     b = Flux.create_bias(W, bias, out*heads)
     a = init(2*out, heads)
+    negative_slope = convert(eltype(W), negative_slope)
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end

@@ -437,7 +438,7 @@ end


 @doc raw"""
-    NNConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+    NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)

 The continuous kernel-based convolutional operator from the
 [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
@@ -447,7 +448,7 @@ This convolution is also known as the edge-conditioned convolution from the
 Performs the operation

 ```math
-\mathbf{x}_i' = W x_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j
+\mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j
 ```

 where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
@@ -459,6 +460,7 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix

 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
+- ``f``: A (possibly learnable) function acting on edge features.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 - `σ`: Activation function.
 - `bias`: Add learnable bias.
@@ -468,32 +470,39 @@ struct NNConv <: GNNLayer
     weight
     bias
     nn
+    σ
     aggr
 end

 @functor NNConv

-function NNConv(ch::Pair{Int,Int}, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glorot_uniform)
     in, out = ch
     W = init(out, in)
     b = Flux.create_bias(W, bias, out)
-    return NNConv(W, b, nn, aggr)
+    return NNConv(W, b, nn, σ, aggr)
 end

 function compute_message(l::NNConv, x_i, x_j, e_ij)
     nin, nedges = size(x_i)
     W = reshape(l.nn(e_ij), (:, nin, nedges))
-    return NNlib.batched_mul(W, x_j)
+    x_j = reshape(x_j, (nin, 1, nedges)) # needed by batched_mul
+    m = NNlib.batched_mul(W, x_j)
+    return reshape(m, :, nedges)
 end

-update_node(l::NNConv, m, x) = l.weight*x + m
+function update_node(l::NNConv, m, x)
+    l.σ.(l.weight*x .+ m .+ l.bias)
+end

 function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
-    check_num_nodes(g, X)
+    check_num_nodes(g, x)
     x, _ = propagate(l, g, l.aggr, x, e)
-    return l.σ.(x + l.bias)
+    return x
 end

+(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))
+
 function Base.show(io::IO, l::NNConv)
     out, in = size(l.weight)
     print(io, "NNConv( $in => $out")
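To make the new three-argument constructor concrete, here is a minimal usage sketch in Julia. The graph construction, dimensions, and variable names below are illustrative assumptions, not part of this commit; only the `NNConv(in => out, f, σ; aggr)` signature and the `l(g, x, e)` / `l(g)` calls come from the code above.

    using GraphNeuralNetworks, Flux

    nin, nout, edim = 3, 5, 10

    # Toy graph: 4 nodes and 4 directed edges in COO form (assumed constructor),
    # with nin-dimensional node features and edim-dimensional edge features.
    s, t = [1, 2, 3, 4], [2, 3, 4, 1]
    g = GNNGraph(s, t, ndata=rand(Float32, nin, 4), edata=rand(Float32, edim, 4))

    # Edge network: maps each edge feature vector to out*in values, which
    # compute_message reshapes into an nout × nin weight matrix per edge.
    nn = Dense(edim, nout * nin)
    l = NNConv(nin => nout, nn, relu; aggr=+)

    y  = l(g, node_features(g), edge_features(g))   # nout × num_nodes matrix
    g′ = l(g)                                       # same result stored as ndata of a new GNNGraph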

test/layers/conv.jl

Lines changed: 18 additions & 32 deletions
@@ -44,7 +44,7 @@
         @test size(l.bias) == (out_channel,)
         @test l.k == k
         for g in test_graphs
-            gradtest(l, g, rtol=1e-4, broken_grad_fields=[:weight])
+            gradtest(l, g, rtol=1e-5, broken_grad_fields=[:weight])
         end

         @testset "bias=false" begin
@@ -72,37 +72,13 @@

     @testset "GATConv" begin

-        heads = 1
-        concat = true
-        l = GATConv(in_channel => out_channel; heads, concat)
-        for g in test_graphs
-            gradtest(l, g, rtol=1e-4)
-        end
-
-        heads = 2
-        concat = true
-        l = GATConv(in_channel => out_channel; heads, concat)
-        for g in test_graphs
-            gradtest(l, g, rtol=1e-4,
-                     broken_grad_fields = [:a])
-        end
-
-        heads = 1
-        concat = false
-        l = GATConv(in_channel => out_channel; heads, concat)
-        for g in test_graphs
-            gradtest(l, g, rtol=1e-4,
-                     broken_grad_fields = [:a])
+        for heads in (1, 2), concat in (true, false)
+            l = GATConv(in_channel => out_channel; heads, concat)
+            for g in test_graphs
+                gradtest(l, g, rtol=1e-4)
+            end
         end

-        heads = 2
-        concat = false
-        l = GATConv(in_channel => out_channel; heads, concat)
-        gradtest(l, test_graphs[1], atol=1e-4, rtol=1e-4,
-                 broken_grad_fields = [:a])
-        gradtest(l, test_graphs[2], atol=1e-4, rtol=1e-4)
-
-
         @testset "bias=false" begin
             @test length(Flux.params(GATConv(2=>3))) == 3
             @test length(Flux.params(GATConv(2=>3, bias=false))) == 2
@@ -115,7 +91,7 @@
         @test size(l.weight) == (out_channel, out_channel, num_layers)

         for g in test_graphs
-            gradtest(l, g, atol=1e-5, rtol=1e-5)
+            gradtest(l, g, rtol=1e-5)
         end
     end

@@ -131,9 +107,19 @@
         eps = 0.001f0
         l = GINConv(nn, eps=eps)
         for g in test_graphs
-            gradtest(l, g, atol=1e-5, rtol=1e-5)
+            gradtest(l, g, rtol=1e-5, exclude_grad_fields=[:eps])
         end

         @test !in(:eps, Flux.trainable(l))
     end
+
+    @testset "NNConv" begin
+        edim = 10
+        nn = Dense(edim, out_channel * in_channel)
+        l = NNConv(in_channel => out_channel, nn)
+        for g in test_graphs
+            g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
+            gradtest(l, g, rtol=1e-5)
+        end
+    end
 end

test/test_utils.jl

Lines changed: 68 additions & 18 deletions
@@ -1,4 +1,4 @@
-using ChainRulesTestUtils, FiniteDifferences, Zygote
+using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt

 const rule_config = Zygote.ZygoteRuleConfig()

@@ -11,64 +11,114 @@ end

 function gradtest(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
                   exclude_grad_fields=[],
-                  broken_grad_fields=[]
+                  broken_grad_fields=[],
+                  verbose = false
                   )
     # TODO these give errors, probably some bugs in ChainRulesTestUtils
     # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
     # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)

-    !haskey(g.ndata, :x) && error("Please pass input graph with :x ndata")
+    isnothing(node_features(g)) && error("Please add node data to the input graph")
     fdm = central_fdm(5, 1)

-    x = g.ndata.x
+    x = node_features(g)
+    e = edge_features(g)
+
+    f(l, g) = l(g)
+    f(l, g, x) = isnothing(e) ? l(g, x) : l(g, x, e)
+
+    loss(l, g) = sum(node_features(f(l, g)))
+    loss(l, g, x) = sum(f(l, g, x))
+    loss(l, g, x, e) = sum(l(g, x, e))

+    x64, e64, l64, g64 = to64.([x, e, l, g])
     # TEST OUTPUT
-    y = l(g, x)
+    y = f(l, g, x)
     @test eltype(y) == eltype(x)

-    g′ = l(g)
+    g′ = f(l, g)
     @test g′.ndata.x ≈ y

-    # TEST INPUT GRADIENT
-    x̄ = gradient(x -> sum(l(g, x)), x)[1]
-    x̄_fd = FiniteDifferences.grad(fdm, x -> sum(l(g, x)), x)[1]
+    # TEST X INPUT GRADIENT
+    x̄ = gradient(x -> loss(l, g, x), x)[1]
+    x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64), x64)[1]
    @test x̄ ≈ x̄_fd    atol=atol rtol=rtol

+    if e !== nothing
+        # TEST E INPUT GRADIENT
+        ē = gradient(e -> loss(l, g, x, e), e)[1]
+        ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1]
+        @test ē ≈ ē_fd    atol=atol rtol=rtol
+    end
+
     # TEST LAYER GRADIENT - l(g, x)
-    l̄ = gradient(l -> sum(l(g, x)), l)[1]
-    l̄_fd = FiniteDifferences.grad(fdm, l -> sum(l(g, x)), l)[1]
-    test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields)
+    l̄ = gradient(l -> loss(l, g, x), l)[1]
+    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64), l64)[1]
+    test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
     # TEST LAYER GRADIENT - l(g)
-    l̄ = gradient(l -> sum(l(g).ndata.x), l)[1]
-    l̄_fd = FiniteDifferences.grad(fdm, l -> sum(l(g).ndata.x), l)[1]
-    test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields)
+    l̄ = gradient(l -> loss(l, g), l)[1]
+    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1]
+    test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
 end

 function test_approx_structs(l, l̄, l̄_fd; atol=1e-5, rtol=1e-5,
                              broken_grad_fields=[],
-                             exclude_grad_fields=[])
+                             exclude_grad_fields=[],
+                             verbose=false)
+
     for f in fieldnames(typeof(l))
         f ∈ exclude_grad_fields && continue
         f̄, f̄_fd = getfield(l̄, f), getfield(l̄_fd, f)
+        if verbose
+            println()
+            @show f getfield(l, f) f̄ f̄_fd
+        end
         if isnothing(f̄)
-            # @show f f̄_fd
+            verbose && println("A")
             @test !(f̄_fd isa AbstractArray) || isapprox(f̄_fd, fill!(similar(f̄_fd), 0); atol=atol, rtol=rtol)
         elseif f̄ isa Union{AbstractArray, Number}
+            verbose && println("B")
             @test eltype(f̄) == eltype(getfield(l, f))
             if f ∈ broken_grad_fields
                @test_broken f̄ ≈ f̄_fd    atol=atol rtol=rtol
             else
-                # @show f getfield(l, f) f̄ f̄_fd broken_grad_fields
                @test f̄ ≈ f̄_fd    atol=atol rtol=rtol
             end
         else
+            verbose && println("C")
             test_approx_structs(getfield(l, f), f̄, f̄_fd; broken_grad_fields)
         end
     end
     return true
 end


+"""
+    to32(m)
+
+Convert the `eltype` of model's parameters to `Float32` or `Int32`.
+"""
+function to32(m)
+    f(x::AbstractArray) = eltype(x) <: Integer ? adapt(Int32, x) : adapt(Float32, x)
+    f(x::Number) = typeof(x) <: Integer ? adapt(Int32, x) : adapt(Float32, x)
+    f(x) = adapt(Float32, x)
+    return fmap(f, m)
+end
+
+"""
+    to64(m)
+
+Convert the `eltype` of model's parameters to `Float64` or `Int64`.
+"""
+function to64(m)
+    f(x::AbstractArray) = eltype(x) <: Integer ? adapt(Int64, x) : adapt(Float64, x)
+    f(x::Number) = typeof(x) <: Integer ? adapt(Int64, x) : adapt(Float64, x)
+    f(x) = adapt(Float64, x)
+    return fmap(f, m)
+end
+
+
 # function gpu_gradtest(l, x_cpu = nothing, args...; test_cpu = true)
 #     isnothing(x_cpu) && error("Missing input to test the layers against.")
 #     @testset "$name GPU grad tests" begin
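As a usage note, the finite-difference references above are taken on Float64 copies (`to64.([x, e, l, g])`) so that the comparison against Zygote's Float32 gradients is not dominated by truncation error. A small sketch of what the helper does (assuming Flux's default Float32 parameter initialization; `fmap` is exported by Flux):

    using Flux

    m   = Dense(2, 3)       # Float32 parameters by default
    m64 = to64(m)           # helper defined above
    all(p -> eltype(p) == Float64, Flux.params(m64))   # true

Integer-valued leaves (e.g. graph indices) are widened to Int64 rather than Float64, so graph structure stays valid after conversion.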
