Skip to content

Commit 60c5c88

Browse files
committed
check for CUDA before calling CUDA
1 parent cda3e9f commit 60c5c88

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

test/test_util_gpu.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
using Test
22
using HybridVariationalInference: HybridVariationalInference as HVI
33
using ComponentArrays
4-
using CUDA
4+
using MLDataDevices
5+
import CUDA, cuDNN
56
using FillArrays
67

78
@testset "ones_similar_x" begin
89
A = rand(Float64, 3, 4);
9-
B = CUDA.rand(Float32, 5, 2); # GPU matrix
1010
@test HVI.ones_similar_x(A, 3) isa FillArrays.AbstractFill #Vector
1111
@test HVI.ones_similar_x(A, size(A,1)) isa FillArrays.AbstractFill #Vector#Vector
12-
@test HVI.ones_similar_x(B, size(B,1)) isa CuArray
13-
@test HVI.ones_similar_x(ComponentVector(b=B), size(B,1)) isa CuArray
14-
@test HVI.ones_similar_x(B', size(B,1)) isa CuArray
15-
@test HVI.ones_similar_x(@view(B[:,2]), size(B,1)) isa CuArray
16-
@test HVI.ones_similar_x(ComponentVector(b=B)[:,1], size(B,1)) isa CuArray
12+
end
13+
14+
gdev = gpu_device()
15+
if gdev isa MLDataDevices.CUDADevice
16+
@testset "ones_similar_x" begin
17+
B = CUDA.rand(Float32, 5, 2); # GPU matrix
18+
@test HVI.ones_similar_x(B, size(B,1)) isa CuArray
19+
@test HVI.ones_similar_x(ComponentVector(b=B), size(B,1)) isa CuArray
20+
@test HVI.ones_similar_x(B', size(B,1)) isa CuArray
21+
@test HVI.ones_similar_x(@view(B[:,2]), size(B,1)) isa CuArray
22+
@test HVI.ones_similar_x(ComponentVector(b=B)[:,1], size(B,1)) isa CuArray
23+
end
1724
end
1825

0 commit comments

Comments
 (0)