
Commit 99b25b0

improve cuda tests
1 parent 6f3d4c8 commit 99b25b0

6 files changed: 93 additions & 81 deletions

Project.toml

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"]
 version = "0.1.0"
 
 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -25,7 +26,7 @@ ChainRulesCore = "1"
 DataStructures = "0.18"
 Flux = "0.12"
 KrylovKit = "0.5"
-LearnBase = "0.5"
+LearnBase = "0.4, 0.5"
 LightGraphs = "1.3"
 MacroTools = "0.5"
 NNlib = "0.7"
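For context, each comma-separated compat entry is a caret specifier, so the widened LearnBase bound accepts both breaking release series, 0.4.x and 0.5.x. This can be verified with Pkg's own (internal) parser, used here purely for illustration:

    using Pkg

    spec = Pkg.Types.semver_spec("0.4, 0.5")   # union of [0.4.0, 0.5.0) and [0.5.0, 0.6.0)
    @assert v"0.4.2" in spec
    @assert v"0.5.9" in spec
    @assert !(v"0.6.0" in spec)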

src/layers/conv.jl

Lines changed: 5 additions & 5 deletions
@@ -32,7 +32,7 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in)
-    b = Flux.create_bias(W, bias, out)
+    b = bias ? Flux.create_bias(W, true, out) : false
     GCNConv(W, b, σ)
 end
 
@@ -105,7 +105,7 @@ function ChebConv(ch::Pair{Int,Int}, k::Int;
                   init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in, k)
-    b = Flux.create_bias(W, bias, out)
+    b = bias ? Flux.create_bias(W, true, out) : false
     ChebConv(W, b, k)
 end
 
@@ -172,7 +172,7 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
     in, out = ch
     W1 = init(out, in)
     W2 = init(out, in)
-    b = Flux.create_bias(W1, bias, out)
+    b = bias ? Flux.create_bias(W1, true, out) : false
     GraphConv(W1, W2, b, σ, aggr)
 end
 
@@ -243,7 +243,7 @@ function GATConv(ch::Pair{Int,Int}, σ=identity;
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
-    b = Flux.create_bias(W, bias, out*heads)
+    b = bias ? Flux.create_bias(W, true, out*heads) : false
     a = init(2*out, heads)
     negative_slope = convert(eltype(W), negative_slope)
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
@@ -479,7 +479,7 @@ end
 function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glorot_uniform)
     in, out = ch
     W = init(out, in)
-    b = Flux.create_bias(W, bias, out)
+    b = bias ? Flux.create_bias(W, true, out) : false
     return NNConv(W, b, nn, σ, aggr)
 end
 
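All five constructors above switch to the same pattern: Flux.create_bias is only called when a bias is actually requested, and the literal false, which broadcasts as an additive zero and carries no trainable parameters, is stored otherwise. A minimal sketch of the two branches, assuming the Flux 0.12 API pinned in Project.toml:

    using Flux

    W = Flux.glorot_uniform(4, 3)    # stand-in for a layer's weight matrix (Float32)

    bias = true
    b = bias ? Flux.create_bias(W, true, 4) : false
    # bias == true  -> a 4-element zero Vector{Float32}, matching W's eltype
    # bias == false -> the Bool false, so W * x .+ b is simply W * x

Storing false keeps the no-bias case out of Flux.params and out of GPU transfers, which is presumably why the commit decides with an explicit ternary instead of passing the flag through create_bias.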

test/cuda/layers/conv.jl

Lines changed: 0 additions & 49 deletions
This file was deleted.

test/layers/conv.jl

Lines changed: 13 additions & 5 deletions
@@ -30,10 +30,10 @@
         gradtest(l, g, rtol=1e-5)
     end
 
-    # l = GCNConv(in_channel => out_channel, relu, bias=false)
-    # for g in test_graphs
-    #     gradtest(l, g)
-    # end
+    l = GCNConv(in_channel => out_channel, tanh, bias=false)
+    for g in test_graphs
+        gradtest(l, g)
+    end
 end
 
 
@@ -44,7 +44,8 @@
     @test size(l.bias) == (out_channel,)
     @test l.k == k
     for g in test_graphs
-        gradtest(l, g, rtol=1e-5, broken_grad_fields=[:weight])
+        gradtest(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=false)
+        @test_broken gradtest(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=true)
     end
 
     @testset "bias=false" begin
@@ -116,10 +117,17 @@
 @testset "NNConv" begin
     edim = 10
     nn = Dense(edim, out_channel * in_channel)
+
     l = NNConv(in_channel => out_channel, nn)
     for g in test_graphs
         g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
         gradtest(l, g, rtol=1e-5)
     end
+
+    l = NNConv(in_channel => out_channel, nn, tanh, bias=false, aggr=mean)
+    for g in test_graphs
+        g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
+        gradtest(l, g, rtol=1e-5)
+    end
 end
 end
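The ChebConv change splits the CPU and GPU checks: the GPU run is wrapped in @test_broken, which only works because gradtest now ends with return true (see test/test_utils.jl below). Test's broken markers expect an expression that currently fails or throws and will return true once fixed. A minimal illustration of the mechanism, with a hypothetical stand-in for the failing GPU path:

    using Test

    # Hypothetical stand-in for a gradtest call whose GPU kernels still error.
    broken_gpu_check() = error("scalar indexing on a CuArray")

    # Returning a Bool is what makes a check usable under @test / @test_broken.
    passing_check() = (@test 1 + 1 == 2; return true)

    @test passing_check()            # passes
    @test_broken broken_gpu_check()  # recorded as Broken while it throws;
                                     # reports an unexpected pass once it returns true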

test/runtests.jl

Lines changed: 5 additions & 3 deletions
@@ -11,7 +11,7 @@ using Test
 CUDA.allowscalar(false)
 
 include("test_utils.jl")
-include("cuda/test_utils.jl")
+# include("cuda/test_utils.jl")
 
 tests = [
     "gnngraph",
@@ -24,13 +24,15 @@ tests = [
 !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
 
 # Testing all graph types. :sparse is a bit broken at the moment
-@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :sparse, :dense)
+@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo,)
 
     global GRAPH_T = graph_type
+    global TEST_GPU = CUDA.functional() && GRAPH_T != :sparse
+
     for t in tests
         include("$t.jl")
 
-        if CUDA.functional() && GRAPH_T != :sparse && isfile("cuda/$t.jl")
+        if TEST_GPU && isfile("cuda/$t.jl")
             include("cuda/$t.jl")
         end
     end
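The TEST_GPU global set here is read at call time as the default of gradtest's new test_gpu keyword (test/test_utils.jl below), so one switch in the runner gates every GPU assertion while individual tests can still override it, as the ChebConv testset does. A condensed sketch of the pattern, with hypothetical names:

    using CUDA

    HAS_GPU = CUDA.functional()    # decided once per test run

    # Keyword defaults are evaluated at call time, so this reads the global.
    check(x; on_gpu = HAS_GPU) = on_gpu ? Array(cu(x)) == x : true

    check([1, 2, 3])                  # follows the global switch
    check([1, 2, 3]; on_gpu = false)  # per-call opt-out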

test/test_utils.jl

Lines changed: 68 additions & 18 deletions
@@ -1,9 +1,12 @@
-using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt
+using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt, CUDA
+CUDA.allowscalar(false)
+
+# global GRAPH_T = :coo
+# global TEST_GPU = true
 
 const rule_config = Zygote.ZygoteRuleConfig()
 
-# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188
-# is fixed
+# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed
 function FiniteDifferences.to_vec(x::Integer)
     Integer_from_vec(v) = x
     return Int[x], Integer_from_vec
@@ -12,8 +15,10 @@ end
 function gradtest(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
                   exclude_grad_fields=[],
                   broken_grad_fields=[],
-                  verbose = false
+                  verbose = false,
+                  test_gpu = TEST_GPU,
                   )
+
     # TODO these give errors, probably some bugs in ChainRulesTestUtils
     # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
     # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)
@@ -24,75 +29,120 @@ function gradtest(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
     x = node_features(g)
     e = edge_features(g)
 
+    x64, e64, l64, g64 = to64.([x, e, l, g])
+    xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])
+
     f(l, g) = l(g)
-    f(l, g, x) = isnothing(e) ? l(g, x) : l(g, x, e)
+    f(l, g, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
+    f(l, g, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
+    f(l, g, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)
 
     loss(l, g) = sum(node_features(f(l, g)))
     loss(l, g, x) = sum(f(l, g, x))
     loss(l, g, x, e) = sum(l(g, x, e))
 
-    x64, e64, l64, g64 = to64.([x, e, l, g])
+
     # TEST OUTPUT
     y = f(l, g, x)
     @test eltype(y) == eltype(x)
 
     g′ = f(l, g)
     @test g′.ndata.x ≈ y
 
-    # TEST X INPUT GRADIENT
+    if test_gpu
+        ygpu = f(lgpu, ggpu, xgpu)
+        @test ygpu isa CuArray
+        @test eltype(ygpu) == eltype(xgpu)
+        @test Array(ygpu) ≈ y
+    end
+
+
+    # TEST x INPUT GRADIENT
     x̄ = gradient(x -> loss(l, g, x), x)[1]
     x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64), x64)[1]
+    @test eltype(x̄) == eltype(x)
     @test x̄ ≈ x̄_fd atol=atol rtol=rtol
 
+    if test_gpu
+        x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu), xgpu)[1]
+        @test x̄gpu isa CuArray
+        @test eltype(x̄gpu) == eltype(x)
+        @test Array(x̄gpu) ≈ x̄ atol=atol rtol=rtol
+    end
+
+
+    # TEST e INPUT GRADIENT
     if e !== nothing
-        # TEST E INPUT GRADIENT
         ē = gradient(e -> loss(l, g, x, e), e)[1]
        ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1]
+        @test eltype(ē) == eltype(e)
        @test ē ≈ ē_fd atol=atol rtol=rtol
+
+        if test_gpu
+            ēgpu = gradient(egpu -> loss(lgpu, ggpu, xgpu, egpu), egpu)[1]
+            @test ēgpu isa CuArray
+            @test eltype(ēgpu) == eltype(ē)
+            @test Array(ēgpu) ≈ ē atol=atol rtol=rtol
+        end
     end
 
+
     # TEST LAYER GRADIENT - l(g, x)
     l̄ = gradient(l -> loss(l, g, x), l)[1]
     l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64), l64)[1]
     test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
+
+    if test_gpu
+        l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu), lgpu)[1]
+        test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
+    end
+
     # TEST LAYER GRADIENT - l(g)
     l̄ = gradient(l -> loss(l, g), l)[1]
     l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1]
     test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
+
+    return true
 end
 
-function test_approx_structs(l, l̄, l̄_fd; atol=1e-5, rtol=1e-5,
+function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
                              broken_grad_fields=[],
                              exclude_grad_fields=[],
                              verbose=false)
 
     for f in fieldnames(typeof(l))
         f ∈ exclude_grad_fields && continue
-        f̄, f̄_fd = getfield(l̄, f), getfield(l̄_fd, f)
+        f̄, f̄2 = getfield(l̄, f), getfield(l̄2, f)
+        x = getfield(l, f)
         if verbose
-            println()
-            @show f getfield(l, f) f̄ f̄_fd
-        end
+            println()
+            @show f x f̄ f̄2
+        end
         if isnothing(f̄)
             verbose && println("A")
-            @test !(f̄_fd isa AbstractArray) || isapprox(f̄_fd, fill!(similar(f̄_fd), 0); atol=atol, rtol=rtol)
+            @test !(f̄2 isa AbstractArray) || isapprox(f̄2, fill!(similar(f̄2), 0); atol=atol, rtol=rtol)
         elseif f̄ isa Union{AbstractArray, Number}
             verbose && println("B")
-            @test eltype(f̄) == eltype(getfield(l, f))
+            @test eltype(f̄) == eltype(x)
+            if x isa CuArray
+                @test f̄ isa CuArray
+                f̄ = Array(f̄)
+            end
             if f ∈ broken_grad_fields
-                @test_broken f̄ ≈ f̄_fd atol=atol rtol=rtol
+                @test_broken f̄ ≈ f̄2 atol=atol rtol=rtol
             else
-                @test f̄ ≈ f̄_fd atol=atol rtol=rtol
+                @test f̄ ≈ f̄2 atol=atol rtol=rtol
             end
         else
             verbose && println("C")
-            test_approx_structs(getfield(l, f), f̄, f̄_fd; broken_grad_fields)
+            test_approx_structs(x, f̄, f̄2; broken_grad_fields)
         end
     end
     return true
 end
 
 
+
 """
     to32(m)
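gradtest depends on to64 and gpu helpers defined further down this file; the diff view is cut off at the to32 docstring. For orientation, a plausible minimal version of the idea (not the actual implementation), built on Flux's fmap traversal:

    using Flux, CUDA

    # Hypothetical sketches: widen Float32 leaves to Float64, and move array
    # leaves to the GPU. The real helpers live below the truncation point.
    to64(m)  = Flux.fmap(x -> x isa AbstractArray{Float32} ? Float64.(x) : x, m)
    togpu(m) = Flux.fmap(x -> x isa AbstractArray ? CUDA.cu(x) : x, m)

The Float64 copies matter because gradtest compares Zygote's Float32 gradients against FiniteDifferences references, which are only trustworthy at higher precision; the Adapt dependency added in Project.toml presumably supports moving GNNGraph internals between devices in the same spirit.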
