File tree Expand file tree Collapse file tree 3 files changed +44
-0
lines changed Expand file tree Collapse file tree 3 files changed +44
-0
lines changed Original file line number Diff line number Diff line change @@ -4,9 +4,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
4
4
Distributions = " 31c24e10-a181-5473-b8eb-7969acd0382f"
5
5
Flux = " 587475ba-b771-5e3f-ad9e-33799f191a9c"
6
6
JET = " c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
7
+ JLArrays = " 27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
7
8
JLD2 = " 033835bb-8acc-5ee8-8aae-3f567f8a3819"
8
9
JuliaFormatter = " 98e50ef6-434e-11e9-1051-2b60c6c9e899"
9
10
LinearAlgebra = " 37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11
+ Metal = " dde4c033-4e86-420c-a63e-0dd931031962"
10
12
PkgJogger = " 10150987-6cc1-4b76-abee-b1c1cbd91c01"
11
13
Random = " 9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12
14
ReferenceTests = " 324d217c-45ce-50fc-942e-d289b448e8cf"
Original file line number Diff line number Diff line change @@ -33,6 +33,9 @@ using JET
33
33
@info " Testing analyzers on batches..."
34
34
include (" test_batches.jl" )
35
35
end
36
+ @testset " GPU tests" begin
37
+ include (" test_gpu.jl" )
38
+ end
36
39
@testset " Benchmark correctness" begin
37
40
@info " Testing whether benchmarks are up-to-date..."
38
41
include (" test_benchmarks.jl" )
Original file line number Diff line number Diff line change
1
+ using ExplainableAI
2
+ using Test
3
+
4
+ using Flux
5
+ using Metal, JLArrays
6
+
7
+ if Metal. functional ()
8
+ @info " Using Metal as GPU device"
9
+ device = mtl # use Apple Metal locally
10
+ else
11
+ @info " Using JLArrays as GPU device"
12
+ device = jl # use JLArrays to fake GPU array
13
+ end
14
+
15
+ model = Chain (Dense (10 => 32 , relu), Dense (32 => 5 ))
16
+ input = rand (Float32, 10 , 8 )
17
+ @test_nowarn model (input)
18
+
19
+ model_gpu = device (model)
20
+ input_gpu = device (input)
21
+ @test_nowarn model_gpu (input_gpu)
22
+
23
+ analyzer_types = (Gradient, SmoothGrad, InputTimesGradient)
24
+
25
+ @testset " Run analyzer (CPU)" begin
26
+ @testset " $A " for A in analyzer_types
27
+ analyzer = A (model)
28
+ expl = analyze (input, analyzer)
29
+ @test expl isa Explanation
30
+ end
31
+ end
32
+
33
+ @testset " Run analyzer (GPU)" begin
34
+ @testset " $A " for A in analyzer_types
35
+ analyzer_gpu = A (model_gpu)
36
+ expl = analyze (input_gpu, analyzer_gpu)
37
+ @test expl isa Explanation
38
+ end
39
+ end
You can’t perform that action at this time.
0 commit comments