Skip to content

Commit 14840e1

Browse files
improve GNN.jl testing
1 parent 530457c commit 14840e1

File tree

8 files changed

+181
-505
lines changed

8 files changed

+181
-505
lines changed

GNNGraphs/test/test_utils.jl

Lines changed: 0 additions & 221 deletions
Original file line numberDiff line numberDiff line change
@@ -5,224 +5,3 @@ function ngradient(f, x...)
55
fdm = central_fdm(5, 1)
66
return FiniteDifferences.grad(fdm, f, x...)
77
end
8-
9-
const rule_config = Zygote.ZygoteRuleConfig()
10-
11-
# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed
12-
function FiniteDifferences.to_vec(x::Integer)
13-
Integer_from_vec(v) = x
14-
return Int[x], Integer_from_vec
15-
end
16-
17-
# Test that forward pass on cpu and gpu are the same.
18-
# Tests also gradient on cpu and gpu comparing with
19-
# finite difference methods.
20-
# Test gradients with respects to layer weights and to input.
21-
# If `g` has edge features, it is assumed that the layer can
22-
# use them in the forward pass as `l(g, x, e)`.
23-
# Test also gradient with respect to `e`.
24-
function test_layer(l, g::GNNGraph; atol = 1e-5, rtol = 1e-5,
25-
exclude_grad_fields = [],
26-
verbose = false,
27-
test_gpu = TEST_GPU,
28-
outsize = nothing,
29-
outtype = :node)
30-
31-
# TODO these give errors, probably some bugs in ChainRulesTestUtils
32-
# test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
33-
# test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)
34-
35-
isnothing(node_features(g)) && error("Plese add node data to the input graph")
36-
fdm = central_fdm(5, 1)
37-
38-
x = node_features(g)
39-
e = edge_features(g)
40-
use_edge_feat = !isnothing(e)
41-
42-
x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
43-
xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])
44-
45-
f(l, g::GNNGraph) = l(g)
46-
f(l, g::GNNGraph, x, e) = use_edge_feat ? l(g, x, e) : l(g, x)
47-
48-
loss(l, g::GNNGraph) =
49-
if outtype == :node
50-
sum(node_features(f(l, g)))
51-
elseif outtype == :edge
52-
sum(edge_features(f(l, g)))
53-
elseif outtype == :graph
54-
sum(graph_features(f(l, g)))
55-
elseif outtype == :node_edge
56-
gnew = f(l, g)
57-
sum(node_features(gnew)) + sum(edge_features(gnew))
58-
end
59-
60-
function loss(l, g::GNNGraph, x, e)
61-
y = f(l, g, x, e)
62-
if outtype == :node_edge
63-
return sum(y[1]) + sum(y[2])
64-
else
65-
return sum(y)
66-
end
67-
end
68-
69-
# TEST OUTPUT
70-
y = f(l, g, x, e)
71-
if outtype == :node_edge
72-
@assert y isa Tuple
73-
@test eltype(y[1]) == eltype(x)
74-
@test eltype(y[2]) == eltype(e)
75-
@test all(isfinite, y[1])
76-
@test all(isfinite, y[2])
77-
if !isnothing(outsize)
78-
@test size(y[1]) == outsize[1]
79-
@test size(y[2]) == outsize[2]
80-
end
81-
else
82-
@test eltype(y) == eltype(x)
83-
@test all(isfinite, y)
84-
if !isnothing(outsize)
85-
@test size(y) == outsize
86-
end
87-
end
88-
89-
# test same output on different graph formats
90-
gcoo = GNNGraph(g, graph_type = :coo)
91-
ycoo = f(l, gcoo, x, e)
92-
if outtype == :node_edge
93-
@test ycoo[1] y[1]
94-
@test ycoo[2] y[2]
95-
else
96-
@test ycoo y
97-
end
98-
99-
g′ = f(l, g)
100-
if outtype == :node
101-
@test g′.ndata.x y
102-
elseif outtype == :edge
103-
@test g′.edata.e y
104-
elseif outtype == :graph
105-
@test g′.gdata.u y
106-
elseif outtype == :node_edge
107-
@test g′.ndata.x y[1]
108-
@test g′.edata.e y[2]
109-
else
110-
@error "wrong outtype $outtype"
111-
end
112-
if test_gpu
113-
ygpu = f(lgpu, ggpu, xgpu, egpu)
114-
if outtype == :node_edge
115-
@test ygpu[1] isa CuArray
116-
@test eltype(ygpu[1]) == eltype(xgpu)
117-
@test Array(ygpu[1]) y[1]
118-
@test ygpu[2] isa CuArray
119-
@test eltype(ygpu[2]) == eltype(xgpu)
120-
@test Array(ygpu[2]) y[2]
121-
else
122-
@test ygpu isa CuArray
123-
@test eltype(ygpu) == eltype(xgpu)
124-
@test Array(ygpu) y
125-
end
126-
end
127-
128-
# TEST x INPUT GRADIENT
129-
= gradient(x -> loss(l, g, x, e), x)[1]
130-
x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1]
131-
@test eltype(x̄) == eltype(x)
132-
@testx̄_fd atol=atol rtol=rtol
133-
134-
if test_gpu
135-
x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1]
136-
@test x̄gpu isa CuArray
137-
@test eltype(x̄gpu) == eltype(x)
138-
@test Array(x̄gpu)x̄ atol=atol rtol=rtol
139-
end
140-
141-
# TEST e INPUT GRADIENT
142-
if e !== nothing
143-
verbose && println("Test e gradient cpu")
144-
= gradient(e -> loss(l, g, x, e), e)[1]
145-
ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1]
146-
@test eltype(ē) == eltype(e)
147-
@testē_fd atol=atol rtol=rtol
148-
149-
if test_gpu
150-
verbose && println("Test e gradient gpu")
151-
ēgpu = gradient(egpu -> loss(lgpu, ggpu, xgpu, egpu), egpu)[1]
152-
@test ēgpu isa CuArray
153-
@test eltype(ēgpu) == eltype(ē)
154-
@test Array(ēgpu)ē atol=atol rtol=rtol
155-
end
156-
end
157-
158-
# TEST LAYER GRADIENT - l(g, x, e)
159-
= gradient(l -> loss(l, g, x, e), l)[1]
160-
l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1]
161-
test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)
162-
163-
if test_gpu
164-
l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1]
165-
test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose)
166-
end
167-
168-
# TEST LAYER GRADIENT - l(g)
169-
= gradient(l -> loss(l, g), l)[1]
170-
test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)
171-
172-
return true
173-
end
174-
175-
function test_approx_structs(l, l̄, l̄fd; atol = 1e-5, rtol = 1e-5,
176-
exclude_grad_fields = [],
177-
verbose = false)
178-
=isa Base.RefValue ? l̄[] :# Zygote wraps gradient of mutables in RefValue
179-
l̄fd = l̄fd isa Base.RefValue ? l̄fd[] : l̄fd # Zygote wraps gradient of mutables in RefValue
180-
181-
for f in fieldnames(typeof(l))
182-
f exclude_grad_fields && continue
183-
verbose && println("Test gradient of field $f...")
184-
x, g, gfd = getfield(l, f), getfield(l̄, f), getfield(l̄fd, f)
185-
test_approx_structs(x, g, gfd; atol, rtol, exclude_grad_fields, verbose)
186-
verbose && println("... field $f done!")
187-
end
188-
return true
189-
end
190-
191-
function test_approx_structs(x, g::Nothing, gfd; atol, rtol, kws...)
192-
# finite diff gradients has to be zero if present
193-
@test !(gfd isa AbstractArray) || isapprox(gfd, fill!(similar(gfd), 0); atol, rtol)
194-
end
195-
196-
function test_approx_structs(x::Union{AbstractArray, Number},
197-
g::Union{AbstractArray, Number}, gfd; atol, rtol, kws...)
198-
@test eltype(g) == eltype(x)
199-
if x isa CuArray
200-
@test g isa CuArray
201-
g = Array(g)
202-
end
203-
@test ggfd atol=atol rtol=rtol
204-
end
205-
206-
"""
207-
to32(m)
208-
209-
Convert the `eltype` of model's float parameters to `Float32`.
210-
Preserves integer arrays.
211-
"""
212-
to32(m) = _paramtype(Float32, m)
213-
214-
"""
215-
to64(m)
216-
217-
Convert the `eltype` of model's float parameters to `Float64`.
218-
Preserves integer arrays.
219-
"""
220-
to64(m) = _paramtype(Float64, m)
221-
222-
struct GNNEltypeAdaptor{T} end
223-
224-
Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where T = convert(AbstractArray{T}, x)
225-
Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Integer}) where T = x
226-
Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Number}) where T = convert(AbstractArray{T}, x)
227-
228-
_paramtype(::Type{T}, m) where T = fmap(adapt(GNNEltypeAdaptor{T}()), m)

GNNlib/src/layers/conv.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ function gin_conv(l, g::AbstractGNNGraph, x)
253253
xj, xi = expand_srcdst(g, x)
254254

255255
m = propagate(copy_xj, g, l.aggr, xj = xj)
256-
257256
return l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m)
258257
end
259258

GraphNeuralNetworks/test/layers/basic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
Flux.testmode!(gnn)
2020

21-
test_layer(gnn, g, rtol = 1e-5, exclude_grad_fields = [, :σ²])
21+
test_gradients(gnn, g, x, rtol = 1e-5)
2222

2323
@testset "constructor with names" begin
2424
m = GNNChain(GCNConv(din => d),
@@ -53,7 +53,7 @@
5353

5454
Flux.trainmode!(gnn)
5555

56-
test_layer(gnn, g, rtol = 1e-4, atol=1e-4, exclude_grad_fields = [, :σ²])
56+
test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4)
5757
end
5858
end
5959

0 commit comments

Comments
 (0)