Skip to content

Commit e79edc6

Browse files
fix
1 parent 3b0b083 commit e79edc6

File tree

4 files changed

+42
-12
lines changed

4 files changed

+42
-12
lines changed

GNNlib/test/msgpass.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,7 @@ end
168168

169169
@testset "copy_xj +" begin
170170
for g in TEST_GRAPHS
171-
dev = gpu_device(force=true)
172-
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
171+
broken = get_graph_type(g) == :sparse && gpu_backend() == "AMDGPU"
173172
f(g, x) = propagate(copy_xj, g, +, xj = x)
174173
@test test_gradients(
175174
f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false
@@ -179,8 +178,7 @@ end
179178

180179
@testset "copy_xj mean" begin
181180
for g in TEST_GRAPHS
182-
dev = gpu_device(force=true)
183-
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
181+
broken = get_graph_type(g) == :sparse && gpu_backend() == "AMDGPU"
184182
f(g, x) = propagate(copy_xj, g, mean, xj = x)
185183
@test test_gradients(
186184
f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false
@@ -190,8 +188,7 @@ end
190188

191189
@testset "e_mul_xj +" begin
192190
for g in TEST_GRAPHS
193-
dev = gpu_device(force=true)
194-
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
191+
broken = get_graph_type(g) == :sparse && gpu_backend() == "AMDGPU"
195192
e = rand(Float32, size(g.x, 1), g.num_edges)
196193
f(g, x, e) = propagate(e_mul_xj, g, +; xj = x, e)
197194
@test test_gradients(
@@ -207,7 +204,6 @@ end
207204
g = set_edge_weight(g, w)
208205
return propagate(w_mul_xj, g, +, xj = x)
209206
end
210-
dev = gpu_device(force=true)
211207
# @show get_graph_type(g) has_isolated_nodes(g)
212208
# broken = get_graph_type(g) == :sparse
213209
broken = true

GNNlib/test/test_module.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ using Flux: Flux
4545
# from this module
4646
export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
4747
test_gradients, finitediff_withgradient,
48-
check_equal_leaves
48+
check_equal_leaves, gpu_backend
4949

5050

5151
const D_IN = 3
@@ -177,4 +177,18 @@ TEST_GRAPHS = [generate_test_graphs(:coo)...,
177177
generate_test_graphs(:dense)...,
178178
generate_test_graphs(:sparse)...]
179179

180+
181+
function gpu_backend()
182+
dev = gpu_device()
183+
if dev isa CUDADevice
184+
return "CUDA"
185+
elseif dev isa AMDGPUDevice
186+
return "AMDGPU"
187+
elseif dev isa MetalDevice
188+
return "Metal"
189+
else
190+
return "Unknown"
191+
end
192+
end
193+
180194
end # module

GraphNeuralNetworks/test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ end
102102
l = ChebConv(D_IN => D_OUT, k)
103103
for g in TEST_GRAPHS
104104
has_isolated_nodes(g) && continue
105-
broken = get_graph_type(g) == :sparse || gpu_device() isa AMDGPUDevice
105+
broken = get_graph_type(g) == :sparse || gpu_backend() == "AMDGPU"
106106
@test size(l(g, g.x)) == (D_OUT, g.num_nodes) broken=broken
107107
@test test_gradients(
108108
l, g, g.x, rtol = RTOL_LOW, test_gpu = true, compare_finite_diff = false

GraphNeuralNetworks/test/test_module.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,15 @@ using SparseArrays
3838
# from Base
3939
export mean, randn, SparseArrays, AbstractSparseMatrix
4040

41-
# from other packages
42-
export Flux, gradient, Dense, Chain, relu, random_regular_graph, erdos_renyi,
43-
BatchNorm, LayerNorm, Dropout, Parallel
41+
# from Flux.jl
42+
export Flux, gradient, Dense, Chain, relu
43+
BatchNorm, LayerNorm, Dropout, Parallel,
44+
gpu_device, cpu_device, get_device,
45+
CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice,
46+
gpu_backend
47+
48+
# from Graphs.jl
49+
export random_regular_graph, erdos_renyi
4450

4551
# from this module
4652
export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
@@ -178,5 +184,19 @@ TEST_GRAPHS = [generate_test_graphs(:coo)...,
178184
generate_test_graphs(:dense)...,
179185
generate_test_graphs(:sparse)...]
180186

187+
188+
function gpu_backend()
189+
dev = gpu_device()
190+
if dev isa CUDADevice
191+
return "CUDA"
192+
elseif dev isa AMDGPUDevice
193+
return "AMDGPU"
194+
elseif dev isa MetalDevice
195+
return "Metal"
196+
else
197+
return "Unknown"
198+
end
199+
end
200+
181201
end # testmodule
182202

0 commit comments

Comments
 (0)