@@ -22,24 +22,28 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
2222## Keyword arguments
2323- `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
2424 Defaults to `GLOBAL_RNG`.
25+ - `show_progress:Bool`: Show progress meter while sampling augmentations. Defaults to `true`.
2526"""
2627struct NoiseAugmentation{A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG } < :
2728 AbstractXAIMethod
2829 analyzer:: A
2930 n:: Int
3031 distribution:: D
3132 rng:: R
33+ show_progress:: Bool
3234
3335 function NoiseAugmentation (
34- analyzer:: A , n:: Int , distribution:: D , rng:: R = GLOBAL_RNG
36+ analyzer:: A , n:: Int , distribution:: D , rng:: R = GLOBAL_RNG, show_progress = true
3537 ) where {A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG }
3638 n < 1 && throw (ArgumentError (" Number of samples `n` needs to be larger than zero." ))
37- return new {A,D,R} (analyzer, n, distribution, rng)
39+ return new {A,D,R} (analyzer, n, distribution, rng, show_progress )
3840 end
3941end
40- function NoiseAugmentation (analyzer, n:: Int , std:: T = 1.0f0 , rng= GLOBAL_RNG) where {T<: Real }
42+ function NoiseAugmentation (
43+ analyzer, n:: Int , std:: T = 1.0f0 , rng= GLOBAL_RNG, show_progress= true
44+ ) where {T<: Real }
4145 distribution = Normal (zero (T), std^ 2 )
42- return NoiseAugmentation (analyzer, n, distribution, rng)
46+ return NoiseAugmentation (analyzer, n, distribution, rng, show_progress )
4347end
4448
4549function call_analyzer (input, aug:: NoiseAugmentation , ns:: AbstractOutputSelector ; kwargs... )
@@ -48,17 +52,21 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
4852 output_indices = ns (output)
4953 output_selector = AugmentationSelector (output_indices)
5054
55+ p = Progress (aug. n; desc= " Sampling NoiseAugmentation..." , enabled= aug. show_progress)
56+
5157 # First augmentation
52- input_aug = similar (input)
53- input_aug = sample_noise! (input_aug , input, aug)
54- expl_aug = aug. analyzer (input_aug , output_selector)
58+ noisy_input = similar (input)
59+ noisy_input = sample_noise! (noisy_input , input, aug)
60+ expl_aug = aug. analyzer (noisy_input , output_selector)
5561 sum_val = expl_aug. val
62+ next! (p)
5663
5764 # Further augmentations
5865 for _ in 2 : (aug. n)
59- input_aug = sample_noise! (input_aug, input, aug)
60- expl_aug = aug. analyzer (input_aug, output_selector)
61- sum_val += expl_aug. val
66+ noisy_input = sample_noise! (noisy_input, input, aug)
67+ expl_aug = aug. analyzer (noisy_input, output_selector)
68+ sum_val .+ = expl_aug. val
69+ next! (p)
6270 end
6371
6472 # Average explanation
7280function sample_noise! (
7381 out:: A , input:: A , aug:: NoiseAugmentation
7482) where {T,A<: AbstractArray{T} }
75- out .= input .+ rand (aug. rng, aug. distribution, size (input))
83+ out = rand! (aug. rng, aug. distribution, out)
84+ out .+ = input
85+ return out
7686end
7787
7888"""
@@ -114,9 +124,9 @@ function call_analyzer(
114124 # Further augmentations
115125 input_delta = (input - input_ref) / (aug. n - 1 )
116126 for _ in 1 : (aug. n)
117- input_aug += input_delta
127+ input_aug . += input_delta
118128 expl_aug = aug. analyzer (input_aug, output_selector)
119- sum_val += expl_aug. val
129+ sum_val . += expl_aug. val
120130 end
121131
122132 # Average gradients and compute explanation
0 commit comments