|
1 | 1 | @testmodule TemporalConvTestModule begin |
2 | 2 | using GraphNeuralNetworks |
3 | | - export in_channel, out_channel, N, timesteps, g, tg, RTOL_LOW, RTOL_HIGH, ATOL_LOW |
| 3 | + using Statistics |
| 4 | + export in_channel, out_channel, N, timesteps, g, tg, cell_loss, |
| 5 | + RTOL_LOW, ATOL_LOW, RTOL_HIGH |
4 | 6 |
|
5 | 7 | RTOL_LOW = 1e-2 |
6 | | - RTOL_HIGH = 1e-5 |
7 | 8 | ATOL_LOW = 1e-3 |
| 9 | + RTOL_HIGH = 1e-5 |
8 | 10 |
|
9 | 11 | in_channel = 3 |
10 | 12 | out_channel = 5 |
11 | 13 | N = 4 |
12 | 14 | timesteps = 5 |
13 | 15 |
|
| 16 | + cell_loss(cell, g, x...) = mean(cell(g, x...)[1]) |
| 17 | + |
14 | 18 | g = GNNGraph(rand_graph(N, 8), |
15 | 19 | ndata = rand(Float32, in_channel, N), |
16 | 20 | graph_type = :coo) |
|
22 | 26 | @testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin |
23 | 27 | using .TemporalConvTestModule, .TestModule |
24 | 28 | cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel) |
25 | | - h = cell(g, g.x) |
| 29 | + y, h = cell(g, g.x) |
| 30 | + @test y === h |
26 | 31 | @test size(h) == (out_channel, g.num_nodes) |
27 | | - test_gradients(cell, g, g.x, rtol = RTOL_HIGH) |
| 32 | + # with no initial state |
| 33 | + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH) |
| 34 | + # with initial state |
| 35 | + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH) |
28 | 36 | end |
29 | 37 |
|
30 | 38 | @testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin |
31 | 39 | using .TemporalConvTestModule, .TestModule |
32 | | - tgcn = TGCN(in_channel => out_channel) |
| 40 | + layer = TGCN(in_channel => out_channel) |
33 | 41 | x = rand(Float32, in_channel, timesteps, g.num_nodes) |
34 | | - h = tgcn(g, x) |
35 | | - @test size(h) == (out_channel, timesteps, g.num_nodes) |
36 | | - test_gradients(tgcn, g, x, rtol = RTOL_HIGH) |
37 | | - test_gradients(tgcn, g, x, h[:,1,:], rtol = RTOL_HIGH) |
38 | | - |
39 | | - # model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) |
40 | | - # @test size(model(g1, g1.ndata.x)) == (1, N) |
41 | | - # @test model(g1) isa GNNGraph |
| 42 | + state0 = rand(Float32, out_channel, g.num_nodes) |
| 43 | + y = layer(g, x) |
| 44 | + @test layer isa GNNRecurrence |
| 45 | + @test size(y) == (out_channel, timesteps, g.num_nodes) |
| 46 | + # with no initial state |
| 47 | + test_gradients(layer, g, x, rtol = RTOL_HIGH) |
| 48 | + # with initial state |
| 49 | + test_gradients(layer, g, x, state0, rtol = RTOL_HIGH) |
| 50 | + |
| 51 | + # interplay with GNNChain |
| 52 | + model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) |
| 53 | + y = model(g, x) |
| 54 | + @test size(y) == (1, timesteps, g.num_nodes) |
| 55 | + test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW) |
42 | 56 | end |
43 | 57 |
|
44 | | -# @testitem "A3TGCN" setup=[TemporalConvTestModule, TestModule] begin |
45 | | -# using .TemporalConvTestModule, .TestModule |
46 | | -# a3tgcn = A3TGCN(in_channel => out_channel) |
47 | | -# @test size(Flux.gradient(x -> sum(a3tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N) |
48 | | -# model = GNNChain(A3TGCN(in_channel => out_channel), Dense(out_channel, 1)) |
49 | | -# @test size(model(g1, g1.ndata.x)) == (1, N) |
50 | | -# @test model(g1) isa GNNGraph |
51 | | -# end |
| 58 | +@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin |
| 59 | + using .TemporalConvTestModule, .TestModule |
| 60 | + cell = GConvLSTMCell(in_channel => out_channel, 2) |
| 61 | + y, (h, c) = cell(g, g.x) |
| 62 | + @test y === h |
| 63 | + @test size(h) == (out_channel, g.num_nodes) |
| 64 | + @test size(c) == (out_channel, g.num_nodes) |
| 65 | + # with no initial state |
| 66 | + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 67 | + # with initial state |
| 68 | + test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 69 | +end |
52 | 70 |
|
53 | | -# @testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin |
54 | | -# using .TemporalConvTestModule, .TestModule |
55 | | -# gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes) |
56 | | -# (h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) |
57 | | -# @test size(h) == (out_channel, N) |
58 | | -# @test size(c) == (out_channel, N) |
59 | | -# end |
| 71 | +@testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin |
| 72 | + using .TemporalConvTestModule, .TestModule |
| 73 | + layer = GConvLSTM(in_channel => out_channel, 2) |
| 74 | + @test layer isa GNNRecurrence |
| 75 | + x = rand(Float32, in_channel, timesteps, g.num_nodes) |
| 76 | + state0 = (rand(Float32, out_channel, g.num_nodes), rand(Float32, out_channel, g.num_nodes)) |
| 77 | + y = layer(g, x) |
| 78 | + @test size(y) == (out_channel, timesteps, g.num_nodes) |
| 79 | + # with no initial state |
| 80 | + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 81 | + # with initial state |
| 82 | + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 83 | + |
| 84 | + # interplay with GNNChain |
| 85 | + model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1)) |
| 86 | + y = model(g, x) |
| 87 | + @test size(y) == (1, timesteps, g.num_nodes) |
| 88 | + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) |
| 89 | +end |
60 | 90 |
|
61 | | -# @testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin |
62 | | -# using .TemporalConvTestModule, .TestModule |
63 | | -# gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes) |
64 | | -# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) |
65 | | -# model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) |
66 | | -# end |
| 91 | +@testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin |
| 92 | + using .TemporalConvTestModule, .TestModule |
| 93 | + cell = GConvGRUCell(in_channel => out_channel, 2) |
| 94 | + y, h = cell(g, g.x) |
| 95 | + @test y === h |
| 96 | + @test size(h) == (out_channel, g.num_nodes) |
| 97 | + # with no initial state |
| 98 | + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 99 | + # with initial state |
| 100 | + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 101 | +end |
67 | 102 |
|
68 | | -# @testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin |
69 | | -# using .TemporalConvTestModule, .TestModule |
70 | | -# gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes) |
71 | | -# h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) |
72 | | -# @test size(h) == (out_channel, N) |
73 | | -# end |
74 | 103 |
|
75 | | -# @testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin |
76 | | -# using .TemporalConvTestModule, .TestModule |
77 | | -# gconvlstm = GConvGRU(in_channel => out_channel, 2, g1.num_nodes) |
78 | | -# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) |
79 | | -# model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) |
80 | | -# @test size(model(g1, g1.ndata.x)) == (1, N) |
81 | | -# @test model(g1) isa GNNGraph |
82 | | -# end |
| 104 | +@testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin |
| 105 | + using .TemporalConvTestModule, .TestModule |
| 106 | + layer = GConvGRU(in_channel => out_channel, 2) |
| 107 | + @test layer isa GNNRecurrence |
| 108 | + x = rand(Float32, in_channel, timesteps, g.num_nodes) |
| 109 | + state0 = rand(Float32, out_channel, g.num_nodes) |
| 110 | + y = layer(g, x) |
| 111 | + @test size(y) == (out_channel, timesteps, g.num_nodes) |
| 112 | + # with no initial state |
| 113 | + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 114 | + # with initial state |
| 115 | + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 116 | + |
| 117 | + # interplay with GNNChain |
| 118 | + model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) |
| 119 | + y = model(g, x) |
| 120 | + @test size(y) == (1, timesteps, g.num_nodes) |
| 121 | + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) |
| 122 | +end |
83 | 123 |
|
84 | | -# @testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin |
85 | | -# using .TemporalConvTestModule, .TestModule |
86 | | -# dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes) |
87 | | -# @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N) |
88 | | -# model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) |
89 | | -# @test size(model(g1, g1.ndata.x)) == (1, N) |
90 | | -# @test model(g1) isa GNNGraph |
91 | | -# end |
| 124 | +@testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin |
| 125 | + using .TemporalConvTestModule, .TestModule |
| 126 | + cell = DCGRUCell(in_channel => out_channel, 2) |
| 127 | + y, h = cell(g, g.x) |
| 128 | + @test y === h |
| 129 | + @test size(h) == (out_channel, g.num_nodes) |
| 130 | + # with no initial state |
| 131 | + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 132 | + # with initial state |
| 133 | + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 134 | +end |
92 | 135 |
|
93 | | -# @testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin |
94 | | -# using .TemporalConvTestModule, .TestModule |
95 | | -# evolvegcno = EvolveGCNO(in_channel => out_channel) |
96 | | -# @test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S |
97 | | -# @test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N) |
98 | | -# end |
| 136 | +@testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin |
| 137 | + using .TemporalConvTestModule, .TestModule |
| 138 | + layer = DCGRU(in_channel => out_channel, 2) |
| 139 | + @test layer isa GNNRecurrence |
| 140 | + x = rand(Float32, in_channel, timesteps, g.num_nodes) |
| 141 | + state0 = rand(Float32, out_channel, g.num_nodes) |
| 142 | + y = layer(g, x) |
| 143 | + @test size(y) == (out_channel, timesteps, g.num_nodes) |
| 144 | + # with no initial state |
| 145 | + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 146 | + # with initial state |
| 147 | + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 148 | + |
| 149 | + # interplay with GNNChain |
| 150 | + model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) |
| 151 | + y = model(g, x) |
| 152 | + @test size(y) == (1, timesteps, g.num_nodes) |
| 153 | + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) |
| 154 | +end |
| 155 | + |
| 156 | +@testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin |
| 157 | + using .TemporalConvTestModule, .TestModule |
| 158 | + cell = EvolveGCNOCell(in_channel => out_channel) |
| 159 | + y, state = cell(g, g.x) |
| 160 | + @test size(y) == (out_channel, g.num_nodes) |
| 161 | + # with no initial state |
| 162 | + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 163 | + # with initial state |
| 164 | + test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 165 | +end |
| 166 | + |
| 167 | +@testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin |
| 168 | + using .TemporalConvTestModule, .TestModule |
| 169 | + layer = EvolveGCNO(in_channel => out_channel) |
| 170 | + @test layer isa GNNRecurrence |
| 171 | + x = rand(Float32, in_channel, timesteps, g.num_nodes) |
| 172 | + state0 = Flux.initialstates(layer) |
| 173 | + y = layer(g, x) |
| 174 | + @test size(y) == (out_channel, timesteps, g.num_nodes) |
| 175 | + # with no initial state |
| 176 | + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 177 | + # with initial state |
| 178 | + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 179 | + |
| 180 | + # interplay with GNNChain |
| 181 | + model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1)) |
| 182 | + y = model(g, x) |
| 183 | + @test size(y) == (1, timesteps, g.num_nodes) |
| 184 | + test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) |
| 185 | +end |
99 | 186 |
|
100 | 187 | # @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin |
101 | 188 | # using .TemporalConvTestModule, .TestModule |
|
0 commit comments