
Commit 761113a

tests passing
1 parent eeb450a

File tree: 3 files changed, 61 additions and 24 deletions


test/layers/conv.jl

Lines changed: 34 additions & 15 deletions
@@ -27,7 +27,7 @@
 @testset "GCNConv" begin
     l = GCNConv(in_channel => out_channel)
     for g in test_graphs
-        gradtest(l, g)
+        gradtest(l, g, rtol=1e-5)
     end
 
     # l = GCNConv(in_channel => out_channel, relu, bias=false)
@@ -44,7 +44,7 @@
     @test size(l.bias) == (out_channel,)
     @test l.k == k
     for g in test_graphs
-        gradtest(l, g, rtol=1) #TODO broken
+        gradtest(l, g, rtol=1e-4, broken_grad_fields=[:weight])
     end
 
     @testset "bias=false" begin
@@ -56,7 +56,7 @@
 @testset "GraphConv" begin
     l = GraphConv(in_channel => out_channel)
     for g in test_graphs
-        gradtest(l, g, rtol=1e-3)
+        gradtest(l, g, rtol=1e-5)
     end
 
     # l = GraphConv(in_channel => out_channel, relu, bias=false)
@@ -72,18 +72,37 @@
 
 @testset "GATConv" begin
 
-    for heads = [1, 2], concat = [true, false]
-        l = GATConv(in_channel => out_channel; heads, concat)
-        for g in test_graphs
-            gradtest(l, g, atol=1e-3, rtol=1e-2) #TODO
-        end
-
-        # l = GraphConv(in_channel => out_channel, relu, bias=false)
-        # for g in test_graphs
-        #     gradtest(l, g)
-        # end
+    heads = 1
+    concat = true
+    l = GATConv(in_channel => out_channel; heads, concat)
+    for g in test_graphs
+        gradtest(l, g, rtol=1e-4)
+    end
+
+    heads = 2
+    concat = true
+    l = GATConv(in_channel => out_channel; heads, concat)
+    for g in test_graphs
+        gradtest(l, g, rtol=1e-4,
+                 broken_grad_fields = [:a])
     end
 
+    heads = 1
+    concat = false
+    l = GATConv(in_channel => out_channel; heads, concat)
+    for g in test_graphs
+        gradtest(l, g, rtol=1e-4,
+                 broken_grad_fields = [:a])
+    end
+
+    heads = 2
+    concat = false
+    l = GATConv(in_channel => out_channel; heads, concat)
+    gradtest(l, test_graphs[1], atol=1e-4, rtol=1e-4,
+             broken_grad_fields = [:a])
+    gradtest(l, test_graphs[2], atol=1e-4, rtol=1e-4)
+
     @testset "bias=false" begin
         @test length(Flux.params(GATConv(2=>3))) == 3
         @test length(Flux.params(GATConv(2=>3, bias=false))) == 2
@@ -101,9 +120,9 @@
 end
 
 @testset "EdgeConv" begin
-    l = EdgeConv(Dense(2*in_channel, out_channel))
+    l = EdgeConv(Dense(2*in_channel, out_channel), aggr=+)
     for g in test_graphs
-        gradtest(l, g, atol=1e-5, rtol=1e-5)
+        gradtest(l, g, rtol=1e-5)
     end
 end

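Note: every testset above funnels into the shared gradtest helper from test/test_utils.jl (updated below). As a rough, self-contained sketch of how one of these cases runs, assuming test/test_utils.jl has been included and using illustrative stand-ins for the in_channel, out_channel, and test_graphs fixtures defined earlier in the file (they are not part of this diff):

    # Hedged sketch, not part of the commit; fixture values are made up.
    using GraphNeuralNetworks, Flux

    in_channel, out_channel = 3, 5
    s, t = [1, 2, 3, 4], [2, 3, 4, 1]                     # a tiny directed cycle
    g = GNNGraph(s, t, ndata = rand(Float32, in_channel, 4))
    test_graphs = [g]

    # Mirrors the heads = 2, concat = true GATConv case above: the gradient of
    # the attention field :a is expected to disagree with finite differences,
    # so it is checked with @test_broken instead of @test.
    l = GATConv(in_channel => out_channel; heads = 2, concat = true)
    for g in test_graphs
        gradtest(l, g, rtol = 1e-4, broken_grad_fields = [:a])
    end
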
test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 using GraphNeuralNetworks
 using Flux
 using CUDA
-using Flux: gpu, @functor
+using Flux: gpu, @functor, f64, f32
 using LinearAlgebra, Statistics, Random
 using NNlib
 using LearnBase

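The only change here is importing f64 and f32 from Flux, presumably so the test utilities can promote a layer's parameters to Float64 when comparing analytic gradients against finite differences (Float32 round-off can otherwise dominate the tolerances). A minimal illustration of what these standard Flux helpers do, with GraphConv(3 => 5) as an arbitrary example layer:

    using Flux
    using Flux: f64, f32
    using GraphNeuralNetworks

    l = GraphConv(3 => 5)        # parameters are Float32 by default
    l64 = f64(l)                 # same layer, floating-point parameters promoted to Float64
    all(p -> eltype(p) == Float64, Flux.params(l64))   # true
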
test/test_utils.jl

Lines changed: 26 additions & 8 deletions
@@ -2,7 +2,17 @@ using ChainRulesTestUtils, FiniteDifferences, Zygote
 
 const rule_config = Zygote.ZygoteRuleConfig()
 
-function gradtest(l, g::GNNGraph; atol=1e-9, rtol=1e-5)
+# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188
+# is fixed
+function FiniteDifferences.to_vec(x::Integer)
+    Integer_from_vec(v) = x
+    return Int[x], Integer_from_vec
+end
+
+function gradtest(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
+                  exclude_grad_fields=[],
+                  broken_grad_fields=[]
+                  )
     # TODO these give errors, probably some bugs in ChainRulesTestUtils
     # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
     # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)
@@ -27,27 +37,35 @@ function gradtest(l, g::GNNGraph; atol=1e-9, rtol=1e-5)
     # TEST LAYER GRADIENT - l(g, x)
     l̄ = gradient(l -> sum(l(g, x)), l)[1]
     l̄_fd = FiniteDifferences.grad(fdm, l -> sum(l(g, x)), l)[1]
-    test_approx_structs(l, l̄, l̄_fd; atol, rtol)
-
+    test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields)
     # TEST LAYER GRADIENT - l(g)
     l̄ = gradient(l -> sum(l(g).ndata.x), l)[1]
     l̄_fd = FiniteDifferences.grad(fdm, l -> sum(l(g).ndata.x), l)[1]
-    test_approx_structs(l, l̄, l̄_fd; atol, rtol)
+    test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields)
 end
 
-function test_approx_structs(l, l̄, l̄_fd; atol=1e-9, rtol=1e-5)
+function test_approx_structs(l, l̄, l̄_fd; atol=1e-5, rtol=1e-5,
+                             broken_grad_fields=[],
+                             exclude_grad_fields=[])
     for f in fieldnames(typeof(l))
+        f ∈ exclude_grad_fields && continue
         f̄, f̄_fd = getfield(l̄, f), getfield(l̄_fd, f)
         if isnothing(f̄)
-            @show f̄_fd
+            # @show f f̄_fd
            @test !(f̄_fd isa AbstractArray) || isapprox(f̄_fd, fill!(similar(f̄_fd), 0); atol=atol, rtol=rtol)
         elseif f̄ isa Union{AbstractArray, Number}
             @test eltype(f̄) == eltype(getfield(l, f))
-            @test f̄ ≈ f̄_fd atol=atol rtol=rtol
+            if f ∈ broken_grad_fields
+                @test_broken f̄ ≈ f̄_fd atol=atol rtol=rtol
+            else
+                # @show f getfield(l, f) f̄ f̄_fd broken_grad_fields
+                @test f̄ ≈ f̄_fd atol=atol rtol=rtol
+            end
         else
-            test_approx_structs(getfield(l, f), f̄, f̄_fd)
+            test_approx_structs(getfield(l, f), f̄, f̄_fd; broken_grad_fields)
         end
     end
+    return true
 end
 

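Two mechanisms are introduced here. The FiniteDifferences.to_vec overload is a stop-gap (until the linked FiniteDifferences.jl issue is fixed) that flattens an Integer to a one-element vector and reconstructs it as the original constant, so integer hyperparameters such as ChebConv's k are never perturbed during finite differencing. The exclude_grad_fields / broken_grad_fields keywords then let individual layer tests skip a field entirely or downgrade its gradient check to @test_broken. Assuming the overload above has been evaluated, its behavior can be checked in the REPL like this (values are illustrative):

    using FiniteDifferences

    v, back = FiniteDifferences.to_vec(4)   # v == [4]
    back(v .+ 0.5)                          # returns 4: the closure ignores its argument,
                                            # so the integer is treated as a constant
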