Skip to content

Commit 3377b66

Browse files
committed
Add ProgressMeter to NoiseAugmentation
1 parent ca359ed commit 3377b66

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.10.1"
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -16,6 +17,7 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
1617
ADTypes = "1"
1718
DifferentiationInterface = "0.6"
1819
Distributions = "0.25"
20+
ProgressMeter = "1.10.4"
1921
Random = "<0.0.1, 1"
2022
Reexport = "1"
2123
Statistics = "<0.0.1, 1"

src/input_augmentation.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,20 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
2929
n::Int
3030
distribution::D
3131
rng::R
32+
show_progress::Bool
3233

3334
function NoiseAugmentation(
34-
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG
35+
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG, show_progress=true
3536
) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG}
3637
n < 1 && throw(ArgumentError("Number of samples `n` needs to be larger than zero."))
37-
return new{A,D,R}(analyzer, n, distribution, rng)
38+
return new{A,D,R}(analyzer, n, distribution, rng, show_progress)
3839
end
3940
end
40-
function NoiseAugmentation(analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
41+
function NoiseAugmentation(
42+
analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG, show_progress=true
43+
) where {T<:Real}
4144
distribution = Normal(zero(T), std^2)
42-
return NoiseAugmentation(analyzer, n, distribution, rng)
45+
return NoiseAugmentation(analyzer, n, distribution, rng, show_progress)
4346
end
4447

4548
function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
@@ -48,17 +51,20 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
4851
output_indices = ns(output)
4952
output_selector = AugmentationSelector(output_indices)
5053

54+
p = Progress(aug.n; desc="Sampling NoiseAugmentation...", showspeed=aug.show_progress)
5155
# First augmentation
5256
input_aug = similar(input)
5357
input_aug = sample_noise!(input_aug, input, aug)
5458
expl_aug = aug.analyzer(input_aug, output_selector)
5559
sum_val = expl_aug.val
60+
next!(p)
5661

5762
# Further augmentations
5863
for _ in 2:(aug.n)
5964
input_aug = sample_noise!(input_aug, input, aug)
6065
expl_aug = aug.analyzer(input_aug, output_selector)
6166
sum_val += expl_aug.val
67+
next!(p)
6268
end
6369

6470
# Average explanation

0 commit comments

Comments
 (0)