Skip to content

Commit d409d90

Browse files
fix gpu tests
1 parent f49aae3 commit d409d90

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

src/msgpass.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ copyxj(xi, xj, e) = xj
146146
# ximulxj(xi, xj, e) = xi .* xj
147147
# xiaddxj(xi, xj, e) = xi .+ xj
148148

149-
function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
150-
A = adjacency_matrix(g)
151-
return xj * A
152-
end
149+
# function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
150+
# A = adjacency_matrix(g)
151+
# return xj * A
152+
# end
153153

154154
# function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
155155
# A = adjacency_matrix(g)

test/layers/basic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
testmode!(gnn)
1717

18-
test_layer(gnn, g, rtol=1e-5)
18+
test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[, :σ²])
1919

2020

2121
@testset "Parallel" begin
@@ -29,7 +29,7 @@
2929

3030
testmode!(gnn)
3131

32-
test_layer(gnn, g, rtol=1e-5)
32+
test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[, :σ²])
3333
end
3434
end
3535
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ tests = [
2727
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
2828

2929
# Testing all graph types. :sparse is a bit broken at the moment
30-
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo,)
30+
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo,:sparse,:dense)
3131

3232
global GRAPH_T = graph_type
3333
global TEST_GPU = CUDA.functional() && GRAPH_T != :sparse

test/test_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
9999

100100
# TEST LAYER GRADIENT - l(g, x)
101101
= gradient(l -> loss(l, g, x), l)[1]
102-
=isa Base.RefValue ? l̄[] :# Zygote wraps gradient of mutables in RefValue
103102
l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64), l64)[1]
104103
test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
105104

@@ -110,7 +109,6 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
110109

111110
# TEST LAYER GRADIENT - l(g)
112111
= gradient(l -> loss(l, g), l)[1]
113-
=isa Base.RefValue ? l̄[] :# Zygote wraps gradient of mutables in RefValue
114112
l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1]
115113
test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
116114

@@ -122,6 +120,9 @@ function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
122120
exclude_grad_fields=[],
123121
verbose=false)
124122

123+
=isa Base.RefValue ? l̄[] :# Zygote wraps gradient of mutables in RefValue
124+
l̄2 = l̄2 isa Base.RefValue ? l̄2[] : l̄2 # Zygote wraps gradient of mutables in RefValue
125+
125126
for f in fieldnames(typeof(l))
126127
f exclude_grad_fields && continue
127128
f̄, f̄2 = getfield(l̄, f), getfield(l̄2, f)
@@ -147,7 +148,6 @@ function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
147148
end
148149
else
149150
verbose && println("C")
150-
=isa Base.RefValue ? f̄[] :# Zygote wraps gradient of mutables in RefValue
151151
test_approx_structs(x, f̄, f̄2; exclude_grad_fields, broken_grad_fields, verbose)
152152
end
153153
end

0 commit comments

Comments
 (0)