diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c6f9fb..cfb0363 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: version: - 'lts' - '1' - - 'pre' + # - 'pre' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index b6e04d4..794ea47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # ExplainableAI.jl +## Version `v0.10.2` +- ![Feature][badge-feature] Tested GPU support for `Gradient`, `InputTimesGradient`, `SmoothGrad`, `IntegratedGradients` ([#184]) +- ![Feature][badge-feature] `NoiseAugmentation`s show a progress meter by default. Turn off via `show_progress=false` ([#184]) + ## Version `v0.10.1` - ![Bugfix][badge-bugfix] Fix bug in `NoiseAugmentation` constructor ([#183]) @@ -227,6 +231,7 @@ Performance improvements: [VisionHeatmaps]: https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/ [TextHeatmaps]: https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/ +[#184]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/184 [#183]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/183 [#180]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/180 [#179]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/179 diff --git a/Project.toml b/Project.toml index 8c56c08..ad994f8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "ExplainableAI" uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b" authors = ["Adrian Hill "] -version = "0.10.1" +version = "0.10.2-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -16,6 +17,7 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" ADTypes = 
"1" DifferentiationInterface = "0.6" Distributions = "0.25" +ProgressMeter = "1.10.4" Random = "<0.0.1, 1" Reexport = "1" Statistics = "<0.0.1, 1" diff --git a/src/ExplainableAI.jl b/src/ExplainableAI.jl index 2ed8ea9..f65dd53 100644 --- a/src/ExplainableAI.jl +++ b/src/ExplainableAI.jl @@ -6,7 +6,8 @@ import XAIBase: call_analyzer using Base.Iterators using Distributions: Distribution, Sampleable, Normal -using Random: AbstractRNG, GLOBAL_RNG +using Random: AbstractRNG, GLOBAL_RNG, rand! +using ProgressMeter: Progress, next! # Automatic differentiation using ADTypes: AbstractADType, AutoZygote diff --git a/src/input_augmentation.jl b/src/input_augmentation.jl index 993c416..8753faa 100644 --- a/src/input_augmentation.jl +++ b/src/input_augmentation.jl @@ -22,6 +22,7 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`. ## Keyword arguments - `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`. Defaults to `GLOBAL_RNG`. +- `show_progress::Bool`: Show progress meter while sampling augmentations. Defaults to `true`. 
""" struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <: AbstractXAIMethod @@ -29,17 +30,20 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <: n::Int distribution::D rng::R + show_progress::Bool function NoiseAugmentation( - analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG + analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG, show_progress=true ) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} n < 1 && throw(ArgumentError("Number of samples `n` needs to be larger than zero.")) - return new{A,D,R}(analyzer, n, distribution, rng) + return new{A,D,R}(analyzer, n, distribution, rng, show_progress) end end -function NoiseAugmentation(analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real} +function NoiseAugmentation( + analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG, show_progress=true +) where {T<:Real} distribution = Normal(zero(T), std^2) - return NoiseAugmentation(analyzer, n, distribution, rng) + return NoiseAugmentation(analyzer, n, distribution, rng, show_progress) end function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...) 
@@ -48,17 +52,21 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector output_indices = ns(output) output_selector = AugmentationSelector(output_indices) + p = Progress(aug.n; desc="Sampling NoiseAugmentation...", enabled=aug.show_progress) + # First augmentation - input_aug = similar(input) - input_aug = sample_noise!(input_aug, input, aug) - expl_aug = aug.analyzer(input_aug, output_selector) + noisy_input = similar(input) + noisy_input = sample_noise!(noisy_input, input, aug) + expl_aug = aug.analyzer(noisy_input, output_selector) sum_val = expl_aug.val + next!(p) # Further augmentations for _ in 2:(aug.n) - input_aug = sample_noise!(input_aug, input, aug) - expl_aug = aug.analyzer(input_aug, output_selector) - sum_val += expl_aug.val + noisy_input = sample_noise!(noisy_input, input, aug) + expl_aug = aug.analyzer(noisy_input, output_selector) + sum_val .+= expl_aug.val + next!(p) end # Average explanation @@ -72,7 +80,9 @@ end function sample_noise!( out::A, input::A, aug::NoiseAugmentation ) where {T,A<:AbstractArray{T}} - out .= input .+ rand(aug.rng, aug.distribution, size(input)) + out = rand!(aug.rng, aug.distribution, out) + out .+= input + return out end """ @@ -114,9 +124,9 @@ function call_analyzer( # Further augmentations input_delta = (input - input_ref) / (aug.n - 1) for _ in 1:(aug.n) - input_aug += input_delta + input_aug .+= input_delta expl_aug = aug.analyzer(input_aug, output_selector) - sum_val += expl_aug.val + sum_val .+= expl_aug.val end # Average gradients and compute explanation diff --git a/test/Project.toml b/test/Project.toml index a8c6b4b..1bf524e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,9 +4,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" JLD2 = 
"033835bb-8acc-5ee8-8aae-3f567f8a3819" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" diff --git a/test/references/cnn/IntegratedGradients_max.jld2 b/test/references/cnn/IntegratedGradients_max.jld2 index 4e3b68d..ad72aeb 100644 Binary files a/test/references/cnn/IntegratedGradients_max.jld2 and b/test/references/cnn/IntegratedGradients_max.jld2 differ diff --git a/test/runtests.jl b/test/runtests.jl index c99e335..356f7dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,9 @@ using JET @info "Testing analyzers on batches..." include("test_batches.jl") end + @testset "GPU tests" begin + include("test_gpu.jl") + end @testset "Benchmark correctness" begin @info "Testing whether benchmarks are up-to-date..." include("test_benchmarks.jl") diff --git a/test/test_gpu.jl b/test/test_gpu.jl new file mode 100644 index 0000000..b7e112b --- /dev/null +++ b/test/test_gpu.jl @@ -0,0 +1,39 @@ +using ExplainableAI +using Test + +using Flux +using Metal, JLArrays + +if Metal.functional() + @info "Using Metal as GPU device" + device = mtl # use Apple Metal locally +else + @info "Using JLArrays as GPU device" + device = jl # use JLArrays to fake GPU array +end + +model = Chain(Dense(10 => 32, relu), Dense(32 => 5)) +input = rand(Float32, 10, 8) +@test_nowarn model(input) + +model_gpu = device(model) +input_gpu = device(input) +@test_nowarn model_gpu(input_gpu) + +analyzer_types = (Gradient, SmoothGrad, InputTimesGradient, IntegratedGradients) + +@testset "Run analyzer (CPU)" begin + @testset "$A" for A in analyzer_types + analyzer = A(model) + expl = analyze(input, analyzer) + @test expl isa Explanation + end +end + +@testset "Run analyzer (GPU)" begin + @testset "$A" for A in analyzer_types + 
analyzer_gpu = A(model_gpu) + expl = analyze(input_gpu, analyzer_gpu) + @test expl isa Explanation + end +end