Skip to content

Commit fc67808

Browse files
[GNNLux] fix tests (#468)
1 parent fb394d1 commit fc67808

File tree

3 files changed

+47
-89
lines changed

3 files changed

+47
-89
lines changed

GNNLux/test/layers/basic_tests.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,13 @@
77
@test GNNLayer <: LuxCore.AbstractExplicitLayer
88
end
99

10+
@testset "GNNContainerLayer" begin
11+
@test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer
12+
end
13+
1014
@testset "GNNChain" begin
1115
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
12-
@test GNNChain <: GNNContainerLayer
1316
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
14-
ps = LuxCore.initialparameters(rng, c)
15-
st = LuxCore.initialstates(rng, c)
16-
@test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps)
17-
@test LuxCore.statelength(c) == LuxCore.statelength(st)
18-
y, st′ = c(g, x, ps, st)
19-
@test LuxCore.outputsize(c) == (3,)
20-
@test size(y) == (3, 10)
21-
loss = (x, ps) -> sum(first(c(g, x, ps, st)))
22-
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
17+
test_lux_layer(rng, c, g, x, outputsize=(3,), container=true)
2318
end
2419
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 7 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,89 +5,32 @@
55

66
@testset "GCNConv" begin
77
l = GCNConv(3 => 5, relu)
8-
@test l isa GNNLayer
9-
ps = Lux.initialparameters(rng, l)
10-
st = Lux.initialstates(rng, l)
11-
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
12-
@test Lux.statelength(l) == Lux.statelength(st)
13-
14-
y, _ = l(g, x, ps, st)
15-
@test Lux.outputsize(l) == (5,)
16-
@test size(y) == (5, 10)
17-
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
18-
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
8+
test_lux_layer(rng, l, g, x, outputsize=(5,))
199
end
2010

2111
@testset "ChebConv" begin
2212
l = ChebConv(3 => 5, 2)
23-
@test l isa GNNLayer
24-
ps = Lux.initialparameters(rng, l)
25-
st = Lux.initialstates(rng, l)
26-
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
27-
@test Lux.statelength(l) == Lux.statelength(st)
28-
29-
y, _ = l(g, x, ps, st)
30-
@test Lux.outputsize(l) == (5,)
31-
@test size(y) == (5, 10)
32-
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
33-
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
13+
test_lux_layer(rng, l, g, x, outputsize=(5,))
3414
end
3515

3616
@testset "GraphConv" begin
3717
l = GraphConv(3 => 5, relu)
38-
@test l isa GNNLayer
39-
ps = Lux.initialparameters(rng, l)
40-
st = Lux.initialstates(rng, l)
41-
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
42-
@test Lux.statelength(l) == Lux.statelength(st)
43-
44-
y, _ = l(g, x, ps, st)
45-
@test Lux.outputsize(l) == (5,)
46-
@test size(y) == (5, 10)
47-
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
48-
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
18+
test_lux_layer(rng, l, g, x, outputsize=(5,))
4919
end
5020

5121
@testset "AGNNConv" begin
5222
l = AGNNConv(init_beta=1.0f0)
53-
@test l isa GNNLayer
54-
ps = Lux.initialparameters(rng, l)
55-
st = Lux.initialstates(rng, l)
56-
@test Lux.parameterlength(ps) == 1
57-
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
58-
@test Lux.statelength(l) == Lux.statelength(st)
59-
60-
y, _ = l(g, x, ps, st)
61-
@test size(y) == size(x)
62-
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
63-
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
23+
test_lux_layer(rng, l, g, x, sizey=(3,10))
6424
end
6525

6626
@testset "EdgeConv" begin
6727
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
6828
l = EdgeConv(nn, aggr = +)
69-
@test l isa GNNContainerLayer
70-
ps = Lux.initialparameters(rng, l)
71-
st = Lux.initialstates(rng, l)
72-
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
73-
@test Lux.statelength(l) == Lux.statelength(st)
74-
y, st′ = l(g, x, ps, st)
75-
@test size(y) == (5, 10)
76-
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
77-
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
29+
test_lux_layer(rng, l, g, x, sizey=(5,10), container=true)
7830
end
7931

8032
@testset "CGConv" begin
81-
l = CGConv(3 => 5, residual = true)
82-
@test l isa GNNContainerLayer
83-
ps = Lux.initialparameters(rng, l)
84-
st = Lux.initialstates(rng, l)
85-
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
86-
@test Lux.statelength(l) == Lux.statelength(st)
87-
y, st′ = l(g, x, ps, st)
88-
@test size(y) == (5, 10)
89-
@test Lux.outputsize(l) == (5,)
90-
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
91-
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
33+
l = CGConv(3 => 3, residual = true)
34+
test_lux_layer(rng, l, g, x, outputsize=(3,), container=true)
9235
end
9336
end

GNNLux/test/shared_testsetup.jl

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,43 @@
22

33
import Reexport: @reexport
44

5+
@reexport using Test
56
@reexport using GNNLux
6-
@reexport using Lux, Functors
7-
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test,
8-
Zygote, Statistics
9-
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx
10-
using MLDataDevices
11-
12-
# Some Helper Functions
13-
function get_default_rng(mode::String)
14-
dev = mode == "cpu" ? CPUDevice() :
15-
mode == "cuda" ? CUDADevice() : mode == "amdgpu" ? AMDGPUDevice() : nothing
16-
rng = default_device_rng(dev)
17-
return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
18-
end
7+
@reexport using Lux
8+
@reexport using StableRNGs
9+
@reexport using Random, Statistics
10+
11+
using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
12+
13+
export test_lux_layer
1914

20-
export get_default_rng
15+
function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
16+
outputsize=nothing, sizey=nothing, container=false,
17+
atol=1.0f-2, rtol=1.0f-2)
2118

22-
# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing
19+
if container
20+
@test l isa GNNContainerLayer
21+
else
22+
@test l isa GNNLayer
23+
end
24+
25+
ps = LuxCore.initialparameters(rng, l)
26+
st = LuxCore.initialstates(rng, l)
27+
@test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
28+
@test LuxCore.statelength(l) == LuxCore.statelength(st)
29+
30+
y, st′ = l(g, x, ps, st)
31+
if outputsize !== nothing
32+
@test LuxCore.outputsize(l) == outputsize
33+
end
34+
if sizey !== nothing
35+
@test size(y) == sizey
36+
elseif outputsize !== nothing
37+
@test size(y) == (outputsize..., g.num_nodes)
38+
end
39+
40+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
41+
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
42+
end
2343

2444
end

0 commit comments

Comments
 (0)