Skip to content

Commit 6a23a70

Browse files
tests
1 parent d762110 commit 6a23a70

File tree

3 files changed

+152
-62
lines changed

3 files changed

+152
-62
lines changed

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,8 @@ function Flux.initialstates(cell::EvolveGCNOCell)
696696
return (; weight, lstm)
697697
end
698698

699+
(cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
700+
699701
function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state)
700702
weight, state_lstm = cell.lstm(state.weight, state.lstm)
701703
x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in)))

GraphNeuralNetworks/test/layers/temporalconv.jl

Lines changed: 149 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
@testmodule TemporalConvTestModule begin
22
using GraphNeuralNetworks
3-
export in_channel, out_channel, N, timesteps, g, tg, RTOL_LOW, RTOL_HIGH, ATOL_LOW
3+
using Statistics
4+
export in_channel, out_channel, N, timesteps, g, tg, cell_loss,
5+
RTOL_LOW, ATOL_LOW, RTOL_HIGH
46

57
RTOL_LOW = 1e-2
6-
RTOL_HIGH = 1e-5
78
ATOL_LOW = 1e-3
9+
RTOL_HIGH = 1e-5
810

911
in_channel = 3
1012
out_channel = 5
1113
N = 4
1214
timesteps = 5
1315

16+
cell_loss(cell, g, x...) = mean(cell(g, x...)[1])
17+
1418
g = GNNGraph(rand_graph(N, 8),
1519
ndata = rand(Float32, in_channel, N),
1620
graph_type = :coo)
@@ -22,80 +26,163 @@ end
2226
@testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin
2327
using .TemporalConvTestModule, .TestModule
2428
cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel)
25-
h = cell(g, g.x)
29+
y, h = cell(g, g.x)
30+
@test y === h
2631
@test size(h) == (out_channel, g.num_nodes)
27-
test_gradients(cell, g, g.x, rtol = RTOL_HIGH)
32+
# with no initial state
33+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
34+
# with initial state
35+
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)
2836
end
2937

3038
@testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
3139
using .TemporalConvTestModule, .TestModule
32-
tgcn = TGCN(in_channel => out_channel)
40+
layer = TGCN(in_channel => out_channel)
3341
x = rand(Float32, in_channel, timesteps, g.num_nodes)
34-
h = tgcn(g, x)
35-
@test size(h) == (out_channel, timesteps, g.num_nodes)
36-
test_gradients(tgcn, g, x, rtol = RTOL_HIGH)
37-
test_gradients(tgcn, g, x, h[:,1,:], rtol = RTOL_HIGH)
38-
39-
# model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
40-
# @test size(model(g1, g1.ndata.x)) == (1, N)
41-
# @test model(g1) isa GNNGraph
42+
state0 = rand(Float32, out_channel, g.num_nodes)
43+
y = layer(g, x)
44+
@test layer isa GNNRecurrence
45+
@test size(y) == (out_channel, timesteps, g.num_nodes)
46+
# with no initial state
47+
test_gradients(layer, g, x, rtol = RTOL_HIGH)
48+
# with initial state
49+
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)
50+
51+
# interplay with GNNChain
52+
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
53+
y = model(g, x)
54+
@test size(y) == (1, timesteps, g.num_nodes)
55+
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
4256
end
4357

44-
# @testitem "A3TGCN" setup=[TemporalConvTestModule, TestModule] begin
45-
# using .TemporalConvTestModule, .TestModule
46-
# a3tgcn = A3TGCN(in_channel => out_channel)
47-
# @test size(Flux.gradient(x -> sum(a3tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
48-
# model = GNNChain(A3TGCN(in_channel => out_channel), Dense(out_channel, 1))
49-
# @test size(model(g1, g1.ndata.x)) == (1, N)
50-
# @test model(g1) isa GNNGraph
51-
# end
58+
@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin
59+
using .TemporalConvTestModule, .TestModule
60+
cell = GConvLSTMCell(in_channel => out_channel, 2)
61+
y, (h, c) = cell(g, g.x)
62+
@test y === h
63+
@test size(h) == (out_channel, g.num_nodes)
64+
@test size(c) == (out_channel, g.num_nodes)
65+
# with no initial state
66+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
67+
# with initial state
68+
test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
69+
end
5270

53-
# @testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin
54-
# using .TemporalConvTestModule, .TestModule
55-
# gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes)
56-
# (h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
57-
# @test size(h) == (out_channel, N)
58-
# @test size(c) == (out_channel, N)
59-
# end
71+
@testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin
72+
using .TemporalConvTestModule, .TestModule
73+
layer = GConvLSTM(in_channel => out_channel, 2)
74+
@test layer isa GNNRecurrence
75+
x = rand(Float32, in_channel, timesteps, g.num_nodes)
76+
state0 = (rand(Float32, out_channel, g.num_nodes), rand(Float32, out_channel, g.num_nodes))
77+
y = layer(g, x)
78+
@test size(y) == (out_channel, timesteps, g.num_nodes)
79+
# with no initial state
80+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
81+
# with initial state
82+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
83+
84+
# interplay with GNNChain
85+
model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1))
86+
y = model(g, x)
87+
@test size(y) == (1, timesteps, g.num_nodes)
88+
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
89+
end
6090

61-
# @testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin
62-
# using .TemporalConvTestModule, .TestModule
63-
# gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes)
64-
# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
65-
# model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
66-
# end
91+
@testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin
92+
using .TemporalConvTestModule, .TestModule
93+
cell = GConvGRUCell(in_channel => out_channel, 2)
94+
y, h = cell(g, g.x)
95+
@test y === h
96+
@test size(h) == (out_channel, g.num_nodes)
97+
# with no initial state
98+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
99+
# with initial state
100+
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
101+
end
67102

68-
# @testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin
69-
# using .TemporalConvTestModule, .TestModule
70-
# gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes)
71-
# h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
72-
# @test size(h) == (out_channel, N)
73-
# end
74103

75-
# @testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin
76-
# using .TemporalConvTestModule, .TestModule
77-
# gconvlstm = GConvGRU(in_channel => out_channel, 2, g1.num_nodes)
78-
# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
79-
# model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
80-
# @test size(model(g1, g1.ndata.x)) == (1, N)
81-
# @test model(g1) isa GNNGraph
82-
# end
104+
@testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin
105+
using .TemporalConvTestModule, .TestModule
106+
layer = GConvGRU(in_channel => out_channel, 2)
107+
@test layer isa GNNRecurrence
108+
x = rand(Float32, in_channel, timesteps, g.num_nodes)
109+
state0 = rand(Float32, out_channel, g.num_nodes)
110+
y = layer(g, x)
111+
@test size(y) == (out_channel, timesteps, g.num_nodes)
112+
# with no initial state
113+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
114+
# with initial state
115+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
116+
117+
# interplay with GNNChain
118+
model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
119+
y = model(g, x)
120+
@test size(y) == (1, timesteps, g.num_nodes)
121+
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
122+
end
83123

84-
# @testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin
85-
# using .TemporalConvTestModule, .TestModule
86-
# dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes)
87-
# @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
88-
# model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
89-
# @test size(model(g1, g1.ndata.x)) == (1, N)
90-
# @test model(g1) isa GNNGraph
91-
# end
124+
@testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin
125+
using .TemporalConvTestModule, .TestModule
126+
cell = DCGRUCell(in_channel => out_channel, 2)
127+
y, h = cell(g, g.x)
128+
@test y === h
129+
@test size(h) == (out_channel, g.num_nodes)
130+
# with no initial state
131+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
132+
# with initial state
133+
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
134+
end
92135

93-
# @testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin
94-
# using .TemporalConvTestModule, .TestModule
95-
# evolvegcno = EvolveGCNO(in_channel => out_channel)
96-
# @test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S
97-
# @test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N)
98-
# end
136+
@testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin
137+
using .TemporalConvTestModule, .TestModule
138+
layer = DCGRU(in_channel => out_channel, 2)
139+
@test layer isa GNNRecurrence
140+
x = rand(Float32, in_channel, timesteps, g.num_nodes)
141+
state0 = rand(Float32, out_channel, g.num_nodes)
142+
y = layer(g, x)
143+
@test size(y) == (out_channel, timesteps, g.num_nodes)
144+
# with no initial state
145+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
146+
# with initial state
147+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
148+
149+
# interplay with GNNChain
150+
model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
151+
y = model(g, x)
152+
@test size(y) == (1, timesteps, g.num_nodes)
153+
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
154+
end
155+
156+
@testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin
157+
using .TemporalConvTestModule, .TestModule
158+
cell = EvolveGCNOCell(in_channel => out_channel)
159+
y, state = cell(g, g.x)
160+
@test size(y) == (out_channel, g.num_nodes)
161+
# with no initial state
162+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
163+
# with initial state
164+
test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
165+
end
166+
167+
@testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin
168+
using .TemporalConvTestModule, .TestModule
169+
layer = EvolveGCNO(in_channel => out_channel)
170+
@test layer isa GNNRecurrence
171+
x = rand(Float32, in_channel, timesteps, g.num_nodes)
172+
state0 = Flux.initialstates(layer)
173+
y = layer(g, x)
174+
@test size(y) == (out_channel, timesteps, g.num_nodes)
175+
# with no initial state
176+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
177+
# with initial state
178+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
179+
180+
# interplay with GNNChain
181+
model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1))
182+
y = model(g, x)
183+
@test size(y) == (1, timesteps, g.num_nodes)
184+
test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
185+
end
99186

100187
# @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin
101188
# using .TemporalConvTestModule, .TestModule

GraphNeuralNetworks/test/test_module.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
7171
# @assert isapprox(x, y; rtol, atol)
7272
if !isapprox(x, y; rtol, atol)
7373
equal = false
74+
# @show x y
7475
end
7576
end
7677
end

0 commit comments

Comments
 (0)