Skip to content

Commit a9a493b

Browse files
committed
Add GPU tests
1 parent 3377b66 commit a9a493b

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
55
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
66
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
7+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
78
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
89
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
1012
PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ using JET
3333
@info "Testing analyzers on batches..."
3434
include("test_batches.jl")
3535
end
36+
@testset "GPU tests" begin
37+
include("test_gpu.jl")
38+
end
3639
@testset "Benchmark correctness" begin
3740
@info "Testing whether benchmarks are up-to-date..."
3841
include("test_benchmarks.jl")

test/test_gpu.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using ExplainableAI
2+
using Test
3+
4+
using Flux
5+
using Metal, JLArrays
6+
7+
if Metal.functional()
8+
@info "Using Metal as GPU device"
9+
device = mtl # use Apple Metal locally
10+
else
11+
@info "Using JLArrays as GPU device"
12+
device = jl # use JLArrays to fake GPU array
13+
end
14+
15+
model = Chain(Dense(10 => 32, relu), Dense(32 => 5))
16+
input = rand(Float32, 10, 8)
17+
@test_nowarn model(input)
18+
19+
model_gpu = device(model)
20+
input_gpu = device(input)
21+
@test_nowarn model_gpu(input_gpu)
22+
23+
analyzer_types = (Gradient, SmoothGrad, InputTimesGradient)
24+
25+
@testset "Run analyzer (CPU)" begin
26+
@testset "$A" for A in analyzer_types
27+
analyzer = A(model)
28+
expl = analyze(input, analyzer)
29+
@test expl isa Explanation
30+
end
31+
end
32+
33+
@testset "Run analyzer (GPU)" begin
34+
@testset "$A" for A in analyzer_types
35+
analyzer_gpu = A(model_gpu)
36+
expl = analyze(input_gpu, analyzer_gpu)
37+
@test expl isa Explanation
38+
end
39+
end

0 commit comments

Comments
 (0)