Skip to content

Commit efc29f8

Browse files
authored
Constant memory input augmentations (#180)
1 parent bd0e400 commit efc29f8

File tree

4 files changed

+55
-195
lines changed

4 files changed

+55
-195
lines changed

src/input_augmentation.jl

Lines changed: 55 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -8,80 +8,6 @@ struct AugmentationSelector{I} <: AbstractOutputSelector
88
end
99
# Calling the selector ignores the model output `out` and returns the stored indices.
function (s::AugmentationSelector)(out)
    return s.indices
end
1010

11-
"""
12-
augment_batch_dim(input, n)
13-
14-
Repeat each sample in input batch n-times along batch dimension.
15-
This turns arrays of size `(..., B)` into arrays of size `(..., B*n)`.
16-
17-
## Example
18-
```julia-repl
19-
julia> A = [1 2; 3 4]
20-
2×2 Matrix{Int64}:
21-
1 2
22-
3 4
23-
24-
julia> augment_batch_dim(A, 3)
25-
2×6 Matrix{Int64}:
26-
1 1 1 2 2 2
27-
3 3 3 4 4 4
28-
```
29-
"""
30-
function augment_batch_dim(input::AbstractArray{T,N}, n) where {T,N}
31-
return repeat(input; inner=(ntuple(Returns(1), N - 1)..., n))
32-
end
33-
34-
"""
35-
reduce_augmentation(augmented_input, n)
36-
37-
Reduce augmented input batch by averaging the explanation for each augmented sample.
38-
"""
39-
function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFloat,N}
40-
# Allocate output array
41-
in_size = size(input)
42-
in_size[end] % n != 0 &&
43-
throw(ArgumentError("Can't reduce augmented batch size of $(in_size[end]) by $n"))
44-
out_size = (in_size[1:(end - 1)]..., div(in_size[end], n))
45-
out = similar(input, eltype(input), out_size)
46-
47-
axs = axes(input, N)
48-
colons = ntuple(Returns(:), N - 1)
49-
for (i, ax) in enumerate(first(axs):n:last(axs))
50-
view(out, colons..., i) .=
51-
dropdims(sum(view(input, colons..., ax:(ax + n - 1)); dims=N); dims=N) / n
52-
end
53-
return out
54-
end
55-
56-
"""
57-
augment_indices(indices, n)
58-
59-
Strip batch indices and return indices for batch augmented by n samples.
60-
61-
## Example
62-
```julia-repl
63-
julia> inds = [CartesianIndex(5,1), CartesianIndex(3,2)]
64-
2-element Vector{CartesianIndex{2}}:
65-
CartesianIndex(5, 1)
66-
CartesianIndex(3, 2)
67-
68-
julia> augment_indices(inds, 3)
69-
6-element Vector{CartesianIndex{2}}:
70-
CartesianIndex(5, 1)
71-
CartesianIndex(5, 2)
72-
CartesianIndex(5, 3)
73-
CartesianIndex(3, 4)
74-
CartesianIndex(3, 5)
75-
CartesianIndex(3, 6)
76-
```
77-
"""
78-
function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
79-
indices_wo_batch = [i.I[1:(end - 1)] for i in inds]
80-
return map(enumerate(repeat(indices_wo_batch; inner=n))) do (i, idx)
81-
CartesianIndex{N}(idx..., i)
82-
end
83-
end
84-
8511
"""
8612
NoiseAugmentation(analyzer, n)
8713
NoiseAugmentation(analyzer, n, std::Real)
@@ -104,38 +30,53 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
10430
n::Int
10531
distribution::D
10632
rng::R
107-
end
108-
function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG)
109-
return NoiseAugmentation(analyzer, n, distribution::Sampleable, rng)
33+
34+
# Inner constructor: reject sample counts below two before building the wrapper.
function NoiseAugmentation(
    analyzer::A, n::Int, distribution::D, rng::R
) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG}
    if n < 2
        throw(ArgumentError("Number of noise samples `n` needs to be larger than one."))
    end
    return new{A,D,R}(analyzer, n, distribution, rng)
end
11041
end
11142
# Convenience constructor: zero-mean Gaussian noise parameterised by `std`
# (defaults: `std = 1.0f0`, `rng = GLOBAL_RNG`).
# NOTE(review): `std^2` is passed as the second argument of `Normal`. If that argument
# is the standard deviation rather than the variance, the effective spread of the noise
# is `std^2`, not `std` — confirm this is intended.
function NoiseAugmentation(analyzer, n, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
    noise_distribution = Normal(zero(T), std^2)
    return NoiseAugmentation(analyzer, n, noise_distribution, rng)
end
45+
# Convenience constructor: fall back to the global RNG when only a distribution is given.
#
# Deliberately defined for three arguments only. A four-argument fallback with untyped
# `analyzer`, `n` and `rng` would forward to an identical four-argument call; whenever
# the arguments do not match the typed inner constructor (e.g. `n` is not an `Int`),
# that call dispatches back to the fallback itself and recurses until a
# StackOverflowError. With this definition such calls fail with a MethodError instead,
# while valid four-argument calls still reach the inner constructor directly.
function NoiseAugmentation(analyzer, n, distribution::Sampleable)
    return NoiseAugmentation(analyzer, n, distribution, GLOBAL_RNG)
end
11448

11549
# Run the wrapped analyzer on `aug.n` noisy copies of `input`, one at a time, and
# return an `Explanation` holding the average of the resulting values.
function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
    # Forward pass on the unperturbed input; its selected output indices are reused
    # for every noisy sample through an `AugmentationSelector`.
    output = aug.analyzer.model(input)
    output_indices = ns(output)
    selector = AugmentationSelector(output_indices)

    # A single buffer is allocated here and refilled with fresh noise for each sample.
    noisy_input = sample_noise!(similar(input), input, aug)
    expl = aug.analyzer(noisy_input, selector)
    total = expl.val

    # Remaining samples 2..n
    for _ in 2:(aug.n)
        noisy_input = sample_noise!(noisy_input, input, aug)
        expl = aug.analyzer(noisy_input, selector)
        total += expl.val
    end

    mean_val = total / aug.n

    # The analyzer and heatmap fields are taken from the last explanation computed.
    return Explanation(
        mean_val, input, output, output_indices, expl.analyzer, expl.heatmap, nothing
    )
end
13675

137-
function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T}
138-
return A + T.(rand(rng, distr, size(A)))
76+
# Overwrite `out` with `input` plus noise drawn from `aug.distribution` using `aug.rng`,
# and return `out`. `out` and `input` must have the same array type `A`.
function sample_noise!(
    out::A, input::A, aug::NoiseAugmentation
) where {T,A<:AbstractArray{T}}
    noise = rand(aug.rng, aug.distribution, size(input))
    out .= input .+ noise
    return out
end
14081

14182
"""
@@ -149,6 +90,13 @@ difference between the input and the reference input.
14990
struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
    analyzer::A  # wrapped analyzer
    n::Int       # number of interpolation steps; the constructor requires n >= 2

    # Inner constructor enforcing the lower bound on `n`.
    function InterpolationAugmentation(analyzer::A, n::Int) where {A<:AbstractXAIMethod}
        if n < 2
            msg = "Number of interpolation steps `n` needs to be larger than one."
            throw(ArgumentError(msg))
        end
        return new{A}(analyzer, n)
    end
end
153101

154102
function call_analyzer(
@@ -160,57 +108,25 @@ function call_analyzer(
160108
# Regular forward pass of model
161109
output = aug.analyzer.model(input)
162110
output_indices = ns(output)
163-
164-
# Call regular analyzer on augmented batch
165-
augmented_input = interpolate_batch(input, input_ref, aug.n)
166-
augmented_indices = augment_indices(output_indices, aug.n)
167-
augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices))
111+
output_selector = AugmentationSelector(output_indices)
112+
113+
# First augmentations
114+
input_aug = input_ref
115+
expl_aug = aug.analyzer(input_aug, output_selector)
116+
sum_val = expl_aug.val
117+
118+
# Further augmentations
# The reference input has already been evaluated as the first of the `aug.n`
# interpolation points, so exactly `aug.n - 1` further steps of size `input_delta`
# are taken and the last one lands on `input`. Iterating `1:(aug.n)` here would run
# `aug.n + 1` evaluations in total, step past `input`, and bias the average, which
# divides the accumulated sum by `aug.n`.
input_delta = (input - input_ref) / (aug.n - 1)
for _ in 2:(aug.n)
    input_aug += input_delta
    expl_aug = aug.analyzer(input_aug, output_selector)
    sum_val += expl_aug.val
end
168125

169126
# Average gradients and compute explanation
170-
expl = (input - input_ref) .* reduce_augmentation(augmented_expl.val, aug.n)
127+
val = (input - input_ref) .* sum_val / aug.n
171128

172129
return Explanation(
173-
expl,
174-
input,
175-
output,
176-
output_indices,
177-
augmented_expl.analyzer,
178-
augmented_expl.heatmap,
179-
nothing,
130+
val, input, output, output_indices, expl_aug.analyzer, expl_aug.heatmap, nothing
180131
)
181132
end
182-
183-
"""
184-
interpolate_batch(x, x0, nsamples)
185-
186-
Augment batch along batch dimension using linear interpolation between input `x` and a reference input `x0`.
187-
188-
## Example
189-
```julia-repl
190-
julia> x = Float16.(reshape(1:4, 2, 2))
191-
2×2 Matrix{Float16}:
192-
1.0 3.0
193-
2.0 4.0
194-
195-
julia> x0 = zero(x)
196-
2×2 Matrix{Float16}:
197-
0.0 0.0
198-
0.0 0.0
199-
200-
julia> interpolate_batch(x, x0, 5)
201-
2×10 Matrix{Float16}:
202-
0.0 0.25 0.5 0.75 1.0 0.0 0.75 1.5 2.25 3.0
203-
0.0 0.5 1.0 1.5 2.0 0.0 1.0 2.0 3.0 4.0
204-
```
205-
"""
206-
function interpolate_batch(
207-
x::AbstractArray{T,N}, x0::AbstractArray{T,N}, nsamples
208-
) where {T,N}
209-
in_size = size(x)
210-
outs = similar(x, (in_size[1:(end - 1)]..., in_size[end] * nsamples))
211-
colons = ntuple(Returns(:), N - 1)
212-
for (i, t) in enumerate(range(zero(T), oneunit(T); length=nsamples))
213-
outs[colons..., i:nsamples:end] .= x0 + t * (x - x0)
214-
end
215-
return outs
216-
end
0 Bytes
Binary file not shown.

test/runtests.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ using JET
2121
end
2222
end
2323

24-
@testset "Input augmentation" begin
25-
@info "Testing input augmentation..."
26-
include("test_input_augmentation.jl")
27-
end
2824
@testset "CNN" begin
2925
@info "Testing analyzers on CNN..."
3026
include("test_cnn.jl")

test/test_input_augmentation.jl

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)