@@ -8,80 +8,6 @@ struct AugmentationSelector{I} <: AbstractOutputSelector
8
8
end
9
9
(s:: AugmentationSelector )(out) = s. indices
10
10
11
- """
12
- augment_batch_dim(input, n)
13
-
14
- Repeat each sample in input batch n-times along batch dimension.
15
- This turns arrays of size `(..., B)` into arrays of size `(..., B*n)`.
16
-
17
- ## Example
18
- ```julia-repl
19
- julia> A = [1 2; 3 4]
20
- 2×2 Matrix{Int64}:
21
- 1 2
22
- 3 4
23
-
24
- julia> augment_batch_dim(A, 3)
25
- 2×6 Matrix{Int64}:
26
- 1 1 1 2 2 2
27
- 3 3 3 4 4 4
28
- ```
29
- """
30
- function augment_batch_dim (input:: AbstractArray{T,N} , n) where {T,N}
31
- return repeat (input; inner= (ntuple (Returns (1 ), N - 1 )... , n))
32
- end
33
-
34
- """
35
- reduce_augmentation(augmented_input, n)
36
-
37
- Reduce augmented input batch by averaging the explanation for each augmented sample.
38
- """
39
- function reduce_augmentation (input:: AbstractArray{T,N} , n) where {T<: AbstractFloat ,N}
40
- # Allocate output array
41
- in_size = size (input)
42
- in_size[end ] % n != 0 &&
43
- throw (ArgumentError (" Can't reduce augmented batch size of $(in_size[end ]) by $n " ))
44
- out_size = (in_size[1 : (end - 1 )]. .. , div (in_size[end ], n))
45
- out = similar (input, eltype (input), out_size)
46
-
47
- axs = axes (input, N)
48
- colons = ntuple (Returns (:), N - 1 )
49
- for (i, ax) in enumerate (first (axs): n: last (axs))
50
- view (out, colons... , i) .=
51
- dropdims (sum (view (input, colons... , ax: (ax + n - 1 )); dims= N); dims= N) / n
52
- end
53
- return out
54
- end
55
-
56
- """
57
- augment_indices(indices, n)
58
-
59
- Strip batch indices and return indices for batch augmented by n samples.
60
-
61
- ## Example
62
- ```julia-repl
63
- julia> inds = [CartesianIndex(5,1), CartesianIndex(3,2)]
64
- 2-element Vector{CartesianIndex{2}}:
65
- CartesianIndex(5, 1)
66
- CartesianIndex(3, 2)
67
-
68
- julia> augment_indices(inds, 3)
69
- 6-element Vector{CartesianIndex{2}}:
70
- CartesianIndex(5, 1)
71
- CartesianIndex(5, 2)
72
- CartesianIndex(5, 3)
73
- CartesianIndex(3, 4)
74
- CartesianIndex(3, 5)
75
- CartesianIndex(3, 6)
76
- ```
77
- """
78
- function augment_indices (inds:: Vector{CartesianIndex{N}} , n) where {N}
79
- indices_wo_batch = [i. I[1 : (end - 1 )] for i in inds]
80
- return map (enumerate (repeat (indices_wo_batch; inner= n))) do (i, idx)
81
- CartesianIndex {N} (idx... , i)
82
- end
83
- end
84
-
85
11
"""
86
12
NoiseAugmentation(analyzer, n)
87
13
NoiseAugmentation(analyzer, n, std::Real)
@@ -104,38 +30,53 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
104
30
n:: Int
105
31
distribution:: D
106
32
rng:: R
107
- end
108
- function NoiseAugmentation (analyzer, n, distribution:: Sampleable , rng= GLOBAL_RNG)
109
- return NoiseAugmentation (analyzer, n, distribution:: Sampleable , rng)
33
+
34
+ function NoiseAugmentation (
35
+ analyzer:: A , n:: Int , distribution:: D , rng:: R
36
+ ) where {A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG }
37
+ n < 2 &&
38
+ throw (ArgumentError (" Number of noise samples `n` needs to be larger than one." ))
39
+ return new {A,D,R} (analyzer, n, distribution, rng)
40
+ end
110
41
end
111
42
function NoiseAugmentation (analyzer, n, std:: T = 1.0f0 , rng= GLOBAL_RNG) where {T<: Real }
112
43
return NoiseAugmentation (analyzer, n, Normal (zero (T), std^ 2 ), rng)
113
44
end
45
+ function NoiseAugmentation (analyzer, n, distribution:: Sampleable , rng= GLOBAL_RNG)
46
+ return NoiseAugmentation (analyzer, n, distribution, rng)
47
+ end
114
48
115
49
function call_analyzer (input, aug:: NoiseAugmentation , ns:: AbstractOutputSelector ; kwargs... )
116
50
# Regular forward pass of model
117
51
output = aug. analyzer. model (input)
118
52
output_indices = ns (output)
119
-
120
- # Call regular analyzer on augmented batch
121
- augmented_input = add_noise (augment_batch_dim (input, aug. n), aug. distribution, aug. rng)
122
- augmented_indices = augment_indices (output_indices, aug. n)
123
- augmented_expl = aug. analyzer (augmented_input, AugmentationSelector (augmented_indices))
53
+ output_selector = AugmentationSelector (output_indices)
54
+
55
+ # First augmentation
56
+ input_aug = similar (input)
57
+ input_aug = sample_noise! (input_aug, input, aug)
58
+ expl_aug = aug. analyzer (input_aug, output_selector)
59
+ sum_val = expl_aug. val
60
+
61
+ # Further augmentations
62
+ for _ in 2 : (aug. n)
63
+ input_aug = sample_noise! (input_aug, input, aug)
64
+ expl_aug = aug. analyzer (input_aug, output_selector)
65
+ sum_val += expl_aug. val
66
+ end
124
67
125
68
# Average explanation
69
+ val = sum_val / aug. n
70
+
126
71
return Explanation (
127
- reduce_augmentation (augmented_expl. val, aug. n),
128
- input,
129
- output,
130
- output_indices,
131
- augmented_expl. analyzer,
132
- augmented_expl. heatmap,
133
- nothing ,
72
+ val, input, output, output_indices, expl_aug. analyzer, expl_aug. heatmap, nothing
134
73
)
135
74
end
136
75
137
- function add_noise (A:: AbstractArray{T} , distr:: Distribution , rng:: AbstractRNG ) where {T}
138
- return A + T .(rand (rng, distr, size (A)))
76
+ function sample_noise! (
77
+ out:: A , input:: A , aug:: NoiseAugmentation
78
+ ) where {T,A<: AbstractArray{T} }
79
+ out .= input .+ rand (aug. rng, aug. distribution, size (input))
139
80
end
140
81
141
82
"""
@@ -149,6 +90,13 @@ difference between the input and the reference input.
149
90
struct InterpolationAugmentation{A<: AbstractXAIMethod } <: AbstractXAIMethod
150
91
analyzer:: A
151
92
n:: Int
93
+
94
+ function InterpolationAugmentation (analyzer:: A , n:: Int ) where {A<: AbstractXAIMethod }
95
+ n < 2 && throw (
96
+ ArgumentError (" Number of interpolation steps `n` needs to be larger than one." ),
97
+ )
98
+ return new {A} (analyzer, n)
99
+ end
152
100
end
153
101
154
102
function call_analyzer (
@@ -160,57 +108,25 @@ function call_analyzer(
160
108
# Regular forward pass of model
161
109
output = aug. analyzer. model (input)
162
110
output_indices = ns (output)
163
-
164
- # Call regular analyzer on augmented batch
165
- augmented_input = interpolate_batch (input, input_ref, aug. n)
166
- augmented_indices = augment_indices (output_indices, aug. n)
167
- augmented_expl = aug. analyzer (augmented_input, AugmentationSelector (augmented_indices))
111
+ output_selector = AugmentationSelector (output_indices)
112
+
113
+ # First augmentations
114
+ input_aug = input_ref
115
+ expl_aug = aug. analyzer (input_aug, output_selector)
116
+ sum_val = expl_aug. val
117
+
118
+ # Further augmentations
119
+ input_delta = (input - input_ref) / (aug. n - 1 )
120
+ for _ in 1 : (aug. n)
121
+ input_aug += input_delta
122
+ expl_aug = aug. analyzer (input_aug, output_selector)
123
+ sum_val += expl_aug. val
124
+ end
168
125
169
126
# Average gradients and compute explanation
170
- expl = (input - input_ref) .* reduce_augmentation (augmented_expl . val, aug. n)
127
+ val = (input - input_ref) .* sum_val / aug. n
171
128
172
129
return Explanation (
173
- expl,
174
- input,
175
- output,
176
- output_indices,
177
- augmented_expl. analyzer,
178
- augmented_expl. heatmap,
179
- nothing ,
130
+ val, input, output, output_indices, expl_aug. analyzer, expl_aug. heatmap, nothing
180
131
)
181
132
end
182
-
183
- """
184
- interpolate_batch(x, x0, nsamples)
185
-
186
- Augment batch along batch dimension using linear interpolation between input `x` and a reference input `x0`.
187
-
188
- ## Example
189
- ```julia-repl
190
- julia> x = Float16.(reshape(1:4, 2, 2))
191
- 2×2 Matrix{Float16}:
192
- 1.0 3.0
193
- 2.0 4.0
194
-
195
- julia> x0 = zero(x)
196
- 2×2 Matrix{Float16}:
197
- 0.0 0.0
198
- 0.0 0.0
199
-
200
- julia> interpolate_batch(x, x0, 5)
201
- 2×10 Matrix{Float16}:
202
- 0.0 0.25 0.5 0.75 1.0 0.0 0.75 1.5 2.25 3.0
203
- 0.0 0.5 1.0 1.5 2.0 0.0 1.0 2.0 3.0 4.0
204
- ```
205
- """
206
- function interpolate_batch (
207
- x:: AbstractArray{T,N} , x0:: AbstractArray{T,N} , nsamples
208
- ) where {T,N}
209
- in_size = size (x)
210
- outs = similar (x, (in_size[1 : (end - 1 )]. .. , in_size[end ] * nsamples))
211
- colons = ntuple (Returns (:), N - 1 )
212
- for (i, t) in enumerate (range (zero (T), oneunit (T); length= nsamples))
213
- outs[colons... , i: nsamples: end ] .= x0 + t * (x - x0)
214
- end
215
- return outs
216
- end
0 commit comments