Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
version:
- 'lts'
- '1'
- 'pre'
# - 'pre'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# ExplainableAI.jl

## Version `v0.10.2`
- ![Feature][badge-feature] Tested GPU support for `Gradient`, `InputTimesGradient`, `SmoothGrad`, `IntegratedGradients` ([#184])
- ![Feature][badge-feature] `NoiseAugmentation`s show a progress meter by default. Turn off via `show_progress=false` ([#184])

## Version `v0.10.1`
- ![Bugfix][badge-bugfix] Fix bug in `NoiseAugmentation` constructor ([#183])

Expand Down Expand Up @@ -227,6 +231,7 @@ Performance improvements:
[VisionHeatmaps]: https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/
[TextHeatmaps]: https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/

[#184]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/184
[#183]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/183
[#180]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/180
[#179]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/179
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "ExplainableAI"
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
authors = ["Adrian Hill <[email protected]>"]
version = "0.10.1"
version = "0.10.2-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -16,6 +17,7 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
ADTypes = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25"
ProgressMeter = "1.10.4"
Random = "<0.0.1, 1"
Reexport = "1"
Statistics = "<0.0.1, 1"
Expand Down
3 changes: 2 additions & 1 deletion src/ExplainableAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import XAIBase: call_analyzer

using Base.Iterators
using Distributions: Distribution, Sampleable, Normal
using Random: AbstractRNG, GLOBAL_RNG
using Random: AbstractRNG, GLOBAL_RNG, rand!
using ProgressMeter: Progress, next!

# Automatic differentiation
using ADTypes: AbstractADType, AutoZygote
Expand Down
36 changes: 23 additions & 13 deletions src/input_augmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,28 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
## Keyword arguments
- `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
Defaults to `GLOBAL_RNG`.
- `show_progress:Bool`: Show progress meter while sampling augmentations. Defaults to `true`.
"""
struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
AbstractXAIMethod
analyzer::A
n::Int
distribution::D
rng::R
show_progress::Bool

function NoiseAugmentation(
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG, show_progress=true
) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG}
n < 1 && throw(ArgumentError("Number of samples `n` needs to be larger than zero."))
return new{A,D,R}(analyzer, n, distribution, rng)
return new{A,D,R}(analyzer, n, distribution, rng, show_progress)
end
end
function NoiseAugmentation(analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
function NoiseAugmentation(
analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG, show_progress=true
) where {T<:Real}
distribution = Normal(zero(T), std^2)
return NoiseAugmentation(analyzer, n, distribution, rng)
return NoiseAugmentation(analyzer, n, distribution, rng, show_progress)
end

function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
Expand All @@ -48,17 +52,21 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
output_indices = ns(output)
output_selector = AugmentationSelector(output_indices)

p = Progress(aug.n; desc="Sampling NoiseAugmentation...", enabled=aug.show_progress)

# First augmentation
input_aug = similar(input)
input_aug = sample_noise!(input_aug, input, aug)
expl_aug = aug.analyzer(input_aug, output_selector)
noisy_input = similar(input)
noisy_input = sample_noise!(noisy_input, input, aug)
expl_aug = aug.analyzer(noisy_input, output_selector)
sum_val = expl_aug.val
next!(p)

# Further augmentations
for _ in 2:(aug.n)
input_aug = sample_noise!(input_aug, input, aug)
expl_aug = aug.analyzer(input_aug, output_selector)
sum_val += expl_aug.val
noisy_input = sample_noise!(noisy_input, input, aug)
expl_aug = aug.analyzer(noisy_input, output_selector)
sum_val .+= expl_aug.val
next!(p)
end

# Average explanation
Expand All @@ -72,7 +80,9 @@ end
function sample_noise!(
out::A, input::A, aug::NoiseAugmentation
) where {T,A<:AbstractArray{T}}
out .= input .+ rand(aug.rng, aug.distribution, size(input))
out = rand!(aug.rng, aug.distribution, out)
out .+= input
return out
end

"""
Expand Down Expand Up @@ -114,9 +124,9 @@ function call_analyzer(
# Further augmentations
input_delta = (input - input_ref) / (aug.n - 1)
for _ in 1:(aug.n)
input_aug += input_delta
input_aug .+= input_delta
expl_aug = aug.analyzer(input_aug, output_selector)
sum_val += expl_aug.val
sum_val .+= expl_aug.val
end

# Average gradients and compute explanation
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
Expand Down
Binary file modified test/references/cnn/IntegratedGradients_max.jld2
Binary file not shown.
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ using JET
@info "Testing analyzers on batches..."
include("test_batches.jl")
end
@testset "GPU tests" begin
include("test_gpu.jl")
end
@testset "Benchmark correctness" begin
@info "Testing whether benchmarks are up-to-date..."
include("test_benchmarks.jl")
Expand Down
39 changes: 39 additions & 0 deletions test/test_gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using ExplainableAI
using Test

using Flux
using Metal, JLArrays

if Metal.functional()
@info "Using Metal as GPU device"
device = mtl # use Apple Metal locally
else
@info "Using JLArrays as GPU device"
device = jl # use JLArrays to fake GPU array
end

model = Chain(Dense(10 => 32, relu), Dense(32 => 5))
input = rand(Float32, 10, 8)
@test_nowarn model(input)

model_gpu = device(model)
input_gpu = device(input)
@test_nowarn model_gpu(input_gpu)

analyzer_types = (Gradient, SmoothGrad, InputTimesGradient, IntegratedGradients)

@testset "Run analyzer (CPU)" begin
@testset "$A" for A in analyzer_types
analyzer = A(model)
expl = analyze(input, analyzer)
@test expl isa Explanation
end
end

@testset "Run analyzer (GPU)" begin
@testset "$A" for A in analyzer_types
analyzer_gpu = A(model_gpu)
expl = analyze(input_gpu, analyzer_gpu)
@test expl isa Explanation
end
end
Loading