Skip to content

Commit 6b3ee3b

Browse files
gix gnn test
1 parent d207de4 commit 6b3ee3b

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

GNNlib/test/test_module.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,17 @@ function finitediff_withgradient(f, x...)
5959
end
6060

6161
function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
62+
equal = true
6263
fmapstructure_with_path(a, b) do kp, x, y
6364
if x isa AbstractArray
6465
# @show kp
65-
@assert x y rtol=rtol atol=atol
66-
# elseif x isa Number
67-
# @show kp
68-
# @assert x ≈ y rtol=rtol atol=atol
66+
# @assert isapprox(x, y; rtol, atol)
67+
if !isapprox(x, y; rtol, atol)
68+
equal = false
69+
end
6970
end
7071
end
72+
@assert equal
7173
end
7274

7375
function test_gradients(
@@ -109,15 +111,15 @@ function test_gradients(
109111
f64 = f |> Flux.f64
110112
xs64 = xs .|> Flux.f64
111113
y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64, graph, xs...), xs64...)
112-
@assert y y_fd rtol=rtol atol=atol
114+
@assert isapprox(y, y_fd; rtol, atol)
113115
check_equal_leaves(g, g_fd; rtol, atol)
114116
end
115117

116118
if test_gpu
117119
# Zygote gradient with respect to input on GPU.
118120
y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, graph_gpu, xs...), xs_gpu...)
119121
@assert get_device(g_gpu) == get_device(xs_gpu)
120-
@assert y_gpu y rtol=rtol atol=atol
122+
@assert isapprox(y_gpu, y; rtol, atol)
121123
check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
122124
end
123125
end
@@ -132,23 +134,22 @@ function test_gradients(
132134
ps, re = Flux.destructure(f64)
133135
y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps),graph, xs...), ps)
134136
g_fd = (re(g_fd[1]),)
135-
@assert y y_fd rtol=rtol atol=atol
137+
@assert isapprox(y, y_fd; rtol, atol)
136138
check_equal_leaves(g, g_fd; rtol, atol)
137139
end
138140

139141
if test_gpu
140142
# Zygote gradient with respect to f on GPU.
141143
y_gpu, g_gpu = Zygote.withgradient(f -> loss(f,graph_gpu, xs_gpu...), f_gpu)
142144
# @assert get_device(g_gpu) == get_device(xs_gpu)
143-
@assert y_gpu y rtol=rtol atol=atol
145+
@assert isapprox(y_gpu, y; rtol, atol)
144146
check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
145147
end
146148
end
147149
@test true # if we reach here, the test passed
148-
return true # return true in case we want to put a @test_broken in the caller
150+
return true
149151
end
150152

151-
152153
function generate_test_graphs(graph_type)
153154
adj1 = [0 1 0 1
154155
1 0 1 0

GraphNeuralNetworks/test/test_module.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,10 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
6363
fmapstructure_with_path(a, b) do kp, x, y
6464
if x isa AbstractArray
6565
# @show kp
66-
# @assert x ≈ y rtol=rtol atol=atol
66+
# @assert isapprox(x, y; rtol, atol)
6767
if !isapprox(x, y; rtol, atol)
6868
equal = false
6969
end
70-
# elseif x isa Number
71-
# @show kp
72-
# @assert x ≈ y rtol=rtol atol=atol
7370
end
7471
end
7572
@assert equal
@@ -114,15 +111,15 @@ function test_gradients(
114111
f64 = f |> Flux.f64
115112
xs64 = xs .|> Flux.f64
116113
y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64, graph, xs...), xs64...)
117-
@assert y y_fd rtol=rtol atol=atol
114+
@assert isapprox(y, y_fd; rtol, atol)
118115
check_equal_leaves(g, g_fd; rtol, atol)
119116
end
120117

121118
if test_gpu
122119
# Zygote gradient with respect to input on GPU.
123120
y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, graph_gpu, xs...), xs_gpu...)
124121
@assert get_device(g_gpu) == get_device(xs_gpu)
125-
@assert y_gpu y rtol=rtol atol=atol
122+
@assert isapprox(y_gpu, y; rtol, atol)
126123
check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
127124
end
128125
end
@@ -137,15 +134,15 @@ function test_gradients(
137134
ps, re = Flux.destructure(f64)
138135
y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps),graph, xs...), ps)
139136
g_fd = (re(g_fd[1]),)
140-
@assert y y_fd rtol=rtol atol=atol
137+
@assert isapprox(y, y_fd; rtol, atol)
141138
check_equal_leaves(g, g_fd; rtol, atol)
142139
end
143140

144141
if test_gpu
145142
# Zygote gradient with respect to f on GPU.
146143
y_gpu, g_gpu = Zygote.withgradient(f -> loss(f,graph_gpu, xs_gpu...), f_gpu)
147144
# @assert get_device(g_gpu) == get_device(xs_gpu)
148-
@assert y_gpu y rtol=rtol atol=atol
145+
@assert isapprox(y_gpu, y; rtol, atol)
149146
check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
150147
end
151148
end

0 commit comments

Comments
 (0)