@@ -22,24 +22,28 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
22
22
## Keyword arguments
23
23
- `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
24
24
Defaults to `GLOBAL_RNG`.
25
+ - `show_progress:Bool`: Show progress meter while sampling augmentations. Defaults to `true`.
25
26
"""
26
27
struct NoiseAugmentation{A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG } < :
27
28
AbstractXAIMethod
28
29
analyzer:: A
29
30
n:: Int
30
31
distribution:: D
31
32
rng:: R
33
+ show_progress:: Bool
32
34
33
35
function NoiseAugmentation (
34
- analyzer:: A , n:: Int , distribution:: D , rng:: R = GLOBAL_RNG
36
+ analyzer:: A , n:: Int , distribution:: D , rng:: R = GLOBAL_RNG, show_progress = true
35
37
) where {A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG }
36
38
n < 1 && throw (ArgumentError (" Number of samples `n` needs to be larger than zero." ))
37
- return new {A,D,R} (analyzer, n, distribution, rng)
39
+ return new {A,D,R} (analyzer, n, distribution, rng, show_progress )
38
40
end
39
41
end
40
- function NoiseAugmentation (analyzer, n:: Int , std:: T = 1.0f0 , rng= GLOBAL_RNG) where {T<: Real }
42
+ function NoiseAugmentation (
43
+ analyzer, n:: Int , std:: T = 1.0f0 , rng= GLOBAL_RNG, show_progress= true
44
+ ) where {T<: Real }
41
45
distribution = Normal (zero (T), std^ 2 )
42
- return NoiseAugmentation (analyzer, n, distribution, rng)
46
+ return NoiseAugmentation (analyzer, n, distribution, rng, show_progress )
43
47
end
44
48
45
49
function call_analyzer (input, aug:: NoiseAugmentation , ns:: AbstractOutputSelector ; kwargs... )
@@ -48,17 +52,21 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
48
52
output_indices = ns (output)
49
53
output_selector = AugmentationSelector (output_indices)
50
54
55
+ p = Progress (aug. n; desc= " Sampling NoiseAugmentation..." , enabled= aug. show_progress)
56
+
51
57
# First augmentation
52
- input_aug = similar (input)
53
- input_aug = sample_noise! (input_aug , input, aug)
54
- expl_aug = aug. analyzer (input_aug , output_selector)
58
+ noisy_input = similar (input)
59
+ noisy_input = sample_noise! (noisy_input , input, aug)
60
+ expl_aug = aug. analyzer (noisy_input , output_selector)
55
61
sum_val = expl_aug. val
62
+ next! (p)
56
63
57
64
# Further augmentations
58
65
for _ in 2 : (aug. n)
59
- input_aug = sample_noise! (input_aug, input, aug)
60
- expl_aug = aug. analyzer (input_aug, output_selector)
61
- sum_val += expl_aug. val
66
+ noisy_input = sample_noise! (noisy_input, input, aug)
67
+ expl_aug = aug. analyzer (noisy_input, output_selector)
68
+ sum_val .+ = expl_aug. val
69
+ next! (p)
62
70
end
63
71
64
72
# Average explanation
72
80
function sample_noise! (
73
81
out:: A , input:: A , aug:: NoiseAugmentation
74
82
) where {T,A<: AbstractArray{T} }
75
- out .= input .+ rand (aug. rng, aug. distribution, size (input))
83
+ out = rand! (aug. rng, aug. distribution, out)
84
+ out .+ = input
85
+ return out
76
86
end
77
87
78
88
"""
@@ -114,9 +124,9 @@ function call_analyzer(
114
124
# Further augmentations
115
125
input_delta = (input - input_ref) / (aug. n - 1 )
116
126
for _ in 1 : (aug. n)
117
- input_aug += input_delta
127
+ input_aug . += input_delta
118
128
expl_aug = aug. analyzer (input_aug, output_selector)
119
- sum_val += expl_aug. val
129
+ sum_val . += expl_aug. val
120
130
end
121
131
122
132
# Average gradients and compute explanation
0 commit comments