|
5 | 5 |
|
6 | 6 | @testset "GCNConv" begin
|
7 | 7 | l = GCNConv(3 => 5, relu)
|
8 |
| - @test l isa GNNLayer |
9 |
| - ps = Lux.initialparameters(rng, l) |
10 |
| - st = Lux.initialstates(rng, l) |
11 |
| - @test Lux.parameterlength(l) == Lux.parameterlength(ps) |
12 |
| - @test Lux.statelength(l) == Lux.statelength(st) |
13 |
| - |
14 |
| - y, _ = l(g, x, ps, st) |
15 |
| - @test Lux.outputsize(l) == (5,) |
16 |
| - @test size(y) == (5, 10) |
17 |
| - loss = (x, ps) -> sum(first(l(g, x, ps, st))) |
18 |
| - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true |
| 8 | + test_lux_layer(rng, l, g, x, outputsize=(5,)) |
19 | 9 | end
|
20 | 10 |
|
21 | 11 | @testset "ChebConv" begin
|
22 | 12 | l = ChebConv(3 => 5, 2)
|
23 |
| - @test l isa GNNLayer |
24 |
| - ps = Lux.initialparameters(rng, l) |
25 |
| - st = Lux.initialstates(rng, l) |
26 |
| - @test Lux.parameterlength(l) == Lux.parameterlength(ps) |
27 |
| - @test Lux.statelength(l) == Lux.statelength(st) |
28 |
| - |
29 |
| - y, _ = l(g, x, ps, st) |
30 |
| - @test Lux.outputsize(l) == (5,) |
31 |
| - @test size(y) == (5, 10) |
32 |
| - loss = (x, ps) -> sum(first(l(g, x, ps, st))) |
33 |
| - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true |
| 13 | + test_lux_layer(rng, l, g, x, outputsize=(5,)) |
34 | 14 | end
|
35 | 15 |
|
36 | 16 | @testset "GraphConv" begin
|
37 | 17 | l = GraphConv(3 => 5, relu)
|
38 |
| - @test l isa GNNLayer |
39 |
| - ps = Lux.initialparameters(rng, l) |
40 |
| - st = Lux.initialstates(rng, l) |
41 |
| - @test Lux.parameterlength(l) == Lux.parameterlength(ps) |
42 |
| - @test Lux.statelength(l) == Lux.statelength(st) |
43 |
| - |
44 |
| - y, _ = l(g, x, ps, st) |
45 |
| - @test Lux.outputsize(l) == (5,) |
46 |
| - @test size(y) == (5, 10) |
47 |
| - loss = (x, ps) -> sum(first(l(g, x, ps, st))) |
48 |
| - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true |
| 18 | + test_lux_layer(rng, l, g, x, outputsize=(5,)) |
49 | 19 | end
|
50 | 20 |
|
51 | 21 | @testset "AGNNConv" begin
|
52 | 22 | l = AGNNConv(init_beta=1.0f0)
|
53 |
| - @test l isa GNNLayer |
54 |
| - ps = Lux.initialparameters(rng, l) |
55 |
| - st = Lux.initialstates(rng, l) |
56 |
| - @test Lux.parameterlength(ps) == 1 |
57 |
| - @test Lux.parameterlength(l) == Lux.parameterlength(ps) |
58 |
| - @test Lux.statelength(l) == Lux.statelength(st) |
59 |
| - |
60 |
| - y, _ = l(g, x, ps, st) |
61 |
| - @test size(y) == size(x) |
62 |
| - loss = (x, ps) -> sum(first(l(g, x, ps, st))) |
63 |
| - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true |
| 23 | + test_lux_layer(rng, l, g, x, sizey=(3,10)) |
64 | 24 | end
|
65 | 25 |
|
66 | 26 | @testset "EdgeConv" begin
|
67 | 27 | nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
|
68 | 28 | l = EdgeConv(nn, aggr = +)
|
69 |
| - @test l isa GNNContainerLayer |
70 |
| - ps = Lux.initialparameters(rng, l) |
71 |
| - st = Lux.initialstates(rng, l) |
72 |
| - @test Lux.parameterlength(l) == Lux.parameterlength(ps) |
73 |
| - @test Lux.statelength(l) == Lux.statelength(st) |
74 |
| - y, st′ = l(g, x, ps, st) |
75 |
| - @test size(y) == (5, 10) |
76 |
| - loss = (x, ps) -> sum(first(l(g, x, ps, st))) |
77 |
| - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true |
| 29 | + test_lux_layer(rng, l, g, x, sizey=(5,10), container=true) |
78 | 30 | end
|
79 | 31 |
|
80 | 32 | @testset "CGConv" begin
|
81 |
| - l = CGConv(3 => 5, residual = true) |
82 |
| - @test l isa GNNContainerLayer |
83 |
| - ps = Lux.initialparameters(rng, l) |
84 |
| - st = Lux.initialstates(rng, l) |
85 |
| - @test Lux.parameterlength(l) == Lux.parameterlength(ps) |
86 |
| - @test Lux.statelength(l) == Lux.statelength(st) |
87 |
| - y, st′ = l(g, x, ps, st) |
88 |
| - @test size(y) == (5, 10) |
89 |
| - @test Lux.outputsize(l) == (5,) |
90 |
| - loss = (x, ps) -> sum(first(l(g, x, ps, st))) |
91 |
| - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true |
| 33 | + l = CGConv(3 => 3, residual = true) |
| 34 | + test_lux_layer(rng, l, g, x, outputsize=(3,), container=true) |
92 | 35 | end
|
93 | 36 | end
|
0 commit comments