
Commit c653600

fix tests
1 parent 8dae39c · commit c653600

File tree: 2 files changed, +65 −26 lines


test/layers/conv.jl

Lines changed: 4 additions & 2 deletions
@@ -186,10 +186,12 @@
     end

     @testset "MEGNetConv" begin
-        l = MEGNetConv(in_channel => out_channel, tanh, aggr=+)
+        l = MEGNetConv(in_channel => out_channel, aggr=+)
         for g in test_graphs
             g = GNNGraph(g, edata=rand(T, in_channel, g.num_edges))
-            test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
+            test_layer(l, g, rtol=1e-5,
+                       outtype=:node_edge,
+                       outsize=((out_channel, g.num_nodes), (out_channel, g.num_edges)))
         end
     end
 end
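
Note (not part of the diff): a minimal sketch of the behavior the updated test asserts, assuming the GraphNeuralNetworks.jl API used elsewhere in the test suite. The toy graph and the 3/5 channel sizes below are illustrative placeholders; MEGNetConv is expected to return a tuple of node and edge features, which is why the harness is told outtype=:node_edge and given a pair of output sizes.

using GraphNeuralNetworks

in_channel, out_channel = 3, 5                   # placeholder sizes
s, t = [1, 1, 2, 3], [2, 3, 3, 1]                # toy COO graph, not from the test suite
g = GNNGraph(s, t)
g = GNNGraph(g, edata=rand(Float32, in_channel, g.num_edges))   # attach edge features, as in the test
x = rand(Float32, in_channel, g.num_nodes)

l = MEGNetConv(in_channel => out_channel, aggr=+)
x′, e′ = l(g, x, edge_features(g))               # returns a (node, edge) tuple
@assert size(x′) == (out_channel, g.num_nodes)
@assert size(e′) == (out_channel, g.num_edges)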

test/test_utils.jl

Lines changed: 61 additions & 24 deletions
@@ -30,66 +30,103 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,

     x = node_features(g)
     e = edge_features(g)
+    use_edge_feat = !isnothing(e)

     x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
     xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])

     f(l, g::GNNGraph) = l(g)
-    f(l, g::GNNGraph, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
-    f(l, g::GNNGraph, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
-    f(l, g::GNNGraph, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)
+    f(l, g::GNNGraph, x, e) = use_edge_feat ? l(g, x, e) : l(g, x)

     loss(l, g::GNNGraph) = if outtype == :node
             sum(node_features(f(l, g)))
         elseif outtype == :edge
             sum(edge_features(f(l, g)))
         elseif outtype == :graph
             sum(graph_features(f(l, g)))
+        elseif outtype == :node_edge
+            gnew = f(l, g)
+            sum(node_features(gnew)) + sum(edge_features(gnew))
         end

-    loss(l, g::GNNGraph, x) = sum(f(l, g, x))
-    loss(l, g::GNNGraph, x, e) = sum(l(g, x, e))
+    function loss(l, g::GNNGraph, x, e)
+        y = f(l, g, x, e)
+        if outtype == :node_edge
+            return sum(y[1]) + sum(y[2])
+        else
+            return sum(y)
+        end
+    end


     # TEST OUTPUT
-    y = f(l, g, x)
-    @test eltype(y) == eltype(x)
-    @test all(isfinite, y)
-    if !isnothing(outsize)
-        @test size(y) == outsize
+    y = f(l, g, x, e)
+    if outtype == :node_edge
+        @assert y isa Tuple
+        @test eltype(y[1]) == eltype(x)
+        @test eltype(y[2]) == eltype(e)
+        @test all(isfinite, y[1])
+        @test all(isfinite, y[2])
+        if !isnothing(outsize)
+            @test size(y[1]) == outsize[1]
+            @test size(y[2]) == outsize[2]
+        end
+    else
+        @test eltype(y) == eltype(x)
+        @test all(isfinite, y)
+        if !isnothing(outsize)
+            @test size(y) == outsize
+        end
     end

     # test same output on different graph formats
     gcoo = GNNGraph(g, graph_type=:coo)
-    ycoo = f(l, gcoo, x)
-    @test ycoo ≈ y
-
+    ycoo = f(l, gcoo, x, e)
+    if outtype == :node_edge
+        @test ycoo[1] ≈ y[1]
+        @test ycoo[2] ≈ y[2]
+    else
+        @test ycoo ≈ y
+    end
+
     g′ = f(l, g)
     if outtype == :node
         @test g′.ndata.x ≈ y
     elseif outtype == :edge
         @test g′.edata.e ≈ y
     elseif outtype == :graph
         @test g′.gdata.u ≈ y
+    elseif outtype == :node_edge
+        @test g′.ndata.x ≈ y[1]
+        @test g′.edata.e ≈ y[2]
     else
         @error "wrong outtype $outtype"
     end
     if test_gpu
-        ygpu = f(lgpu, ggpu, xgpu)
-        @test ygpu isa CuArray
-        @test eltype(ygpu) == eltype(xgpu)
-        @test Array(ygpu) ≈ y
+        ygpu = f(lgpu, ggpu, xgpu, egpu)
+        if outtype == :node_edge
+            @test ygpu[1] isa CuArray
+            @test eltype(ygpu[1]) == eltype(xgpu)
+            @test Array(ygpu[1]) ≈ y[1]
+            @test ygpu[2] isa CuArray
+            @test eltype(ygpu[2]) == eltype(xgpu)
+            @test Array(ygpu[2]) ≈ y[2]
+        else
+            @test ygpu isa CuArray
+            @test eltype(ygpu) == eltype(xgpu)
+            @test Array(ygpu) ≈ y
+        end
     end


     # TEST x INPUT GRADIENT
-    x̄ = gradient(x -> loss(l, g, x), x)[1]
-    x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64), x64)[1]
+    x̄ = gradient(x -> loss(l, g, x, e), x)[1]
+    x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1]
     @test eltype(x̄) == eltype(x)
     @test x̄ ≈ x̄_fd atol=atol rtol=rtol

     if test_gpu
-        x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu), xgpu)[1]
+        x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1]
         @test x̄gpu isa CuArray
         @test eltype(x̄gpu) == eltype(x)
         @test Array(x̄gpu) ≈ x̄ atol=atol rtol=rtol
@@ -112,13 +149,13 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
     end


-    # TEST LAYER GRADIENT - l(g, x)
-    l̄ = gradient(l -> loss(l, g, x), l)[1]
-    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64), l64)[1]
+    # TEST LAYER GRADIENT - l(g, x, e)
+    l̄ = gradient(l -> loss(l, g, x, e), l)[1]
+    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1]
     test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)

     if test_gpu
-        l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu), lgpu)[1]
+        l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1]
         test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
     end

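Note (not part of the diff): a sketch of the whole-graph form that the new :node_edge branch compares against, under the same API assumptions and placeholder sizes as in the earlier sketch. Calling the layer on the GNNGraph itself is expected to store the tuple output in g′.ndata.x and g′.edata.e, which is what the harness checks via g′ = f(l, g).

using GraphNeuralNetworks

in_channel, out_channel = 3, 5                   # placeholder sizes
s, t = [1, 1, 2, 3], [2, 3, 3, 1]                # toy COO graph, not from the test suite
g = GNNGraph(s, t)
g = GNNGraph(g, ndata=rand(Float32, in_channel, g.num_nodes),
                edata=rand(Float32, in_channel, g.num_edges))

l = MEGNetConv(in_channel => out_channel, aggr=+)
x′, e′ = l(g, node_features(g), edge_features(g))   # array form: tuple of outputs
g′ = l(g)                                           # graph form: new GNNGraph
@assert g′.ndata.x ≈ x′                             # node features match
@assert g′.edata.e ≈ e′                             # edge features match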