|
1 |
| -""" |
2 |
| - InputAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG]) |
3 |
| - InputAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG]) |
4 |
| -
|
5 |
| -A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`. |
6 |
| -This input augmentation is then averaged to return an `Explanation`. |
7 |
| -""" |
8 |
| -struct InputAugmentation{A<:AbstractXAIMethod,D<:Distribution,R<:AbstractRNG} <: |
9 |
| - AbstractXAIMethod |
10 |
| - analyzer::A |
11 |
| - n::Integer |
12 |
| - distribution::D |
13 |
| - rng::R |
14 |
| -end |
15 |
| -function InputAugmentation(analyzer, n, distr, rng=GLOBAL_RNG) |
16 |
| - return InputAugmentation(analyzer, n, distr, rng) |
17 |
| -end |
18 |
| -function InputAugmentation(analyzer, n, σ::Real=0.1f0, args...) |
19 |
| - return InputAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...) |
20 |
| -end |
21 |
| - |
22 |
| -function (aug::InputAugmentation)(input, ns::AbstractNeuronSelector) |
23 |
| - # Regular forward pass of model |
24 |
| - output = aug.analyzer.model(input) |
25 |
| - output_indices = ns(output) |
26 |
| - |
27 |
| - # Call regular analyzer on augmented batch |
28 |
| - augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng) |
29 |
| - augmented_indices = augment_indices(output_indices, aug.n) |
30 |
| - augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices)) |
31 |
| - |
32 |
| - # Average explanation |
33 |
| - return Explanation( |
34 |
| - reduce_augmentation(augmented_expl.attribution, aug.n), |
35 |
| - output, |
36 |
| - output_indices, |
37 |
| - augmented_expl.analyzer, |
38 |
| - Nothing, |
39 |
| - ) |
40 |
| -end |
41 |
| - |
42 |
| -function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T} |
43 |
| - return A + T.(rand(rng, distr, size(A))) |
44 |
| -end |
45 |
| - |
46 | 1 | """
|
47 | 2 | augment_batch_dim(input, n)
|
48 | 3 |
|
@@ -80,13 +35,13 @@ function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFlo
|
80 | 35 | out = similar(input, eltype(input), out_size)
|
81 | 36 |
|
82 | 37 | axs = axes(input, N)
|
83 |
| - inds_before_N = ntuple(Returns(:), N - 1) |
| 38 | + colons = ntuple(Returns(:), N - 1) |
84 | 39 | for (i, ax) in enumerate(first(axs):n:last(axs))
|
85 |
| - view(out, inds_before_N..., i) .= |
86 |
| - sum(view(input, inds_before_N..., ax:(ax + n - 1)); dims=N) / n |
| 40 | + view(out, colons..., i) .= sum(view(input, colons..., ax:(ax + n - 1)); dims=N) / n |
87 | 41 | end
|
88 | 42 | return out
|
89 | 43 | end
|
| 44 | + |
90 | 45 | """
|
91 | 46 | augment_indices(indices, n)
|
92 | 47 |
|
@@ -115,3 +70,117 @@ function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
|
115 | 70 | CartesianIndex{N}(idx..., i)
|
116 | 71 | end
|
117 | 72 | end
|
| 73 | + |
| 74 | +""" |
| 75 | + NoiseAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG]) |
| 76 | + NoiseAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG]) |
| 77 | +
|
| 78 | +A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`. |
| 79 | +This input augmentation is then averaged to return an `Explanation`. |
| 80 | +""" |
| 81 | +struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <: |
| 82 | + AbstractXAIMethod |
| 83 | + analyzer::A |
| 84 | + n::Int |
| 85 | + distribution::D |
| 86 | + rng::R |
| 87 | +end |
| 88 | +function NoiseAugmentation(analyzer, n, distr::Sampleable, rng=GLOBAL_RNG) |
| 89 | + return NoiseAugmentation(analyzer, n, distr::Sampleable, rng) |
| 90 | +end |
| 91 | +function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...) |
| 92 | + return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...) |
| 93 | +end |
| 94 | + |
| 95 | +function (aug::NoiseAugmentation)(input, ns::AbstractNeuronSelector) |
| 96 | + # Regular forward pass of model |
| 97 | + output = aug.analyzer.model(input) |
| 98 | + output_indices = ns(output) |
| 99 | + |
| 100 | + # Call regular analyzer on augmented batch |
| 101 | + augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng) |
| 102 | + augmented_indices = augment_indices(output_indices, aug.n) |
| 103 | + augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices)) |
| 104 | + |
| 105 | + # Average explanation |
| 106 | + return Explanation( |
| 107 | + reduce_augmentation(augmented_expl.attribution, aug.n), |
| 108 | + output, |
| 109 | + output_indices, |
| 110 | + augmented_expl.analyzer, |
| 111 | + Nothing, |
| 112 | + ) |
| 113 | +end |
| 114 | + |
| 115 | +function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T} |
| 116 | + return A + T.(rand(rng, distr, size(A))) |
| 117 | +end |
| 118 | + |
| 119 | +""" |
| 120 | + InterpolationAugmentation(model, [n=50]) |
| 121 | +
|
| 122 | +A wrapper around analyzers that augments the input with `n` steps of linear interpolation |
| 123 | +between the input and a reference input (typically `zero(input)`). |
| 124 | +The gradients w.r.t. this augmented input are then averaged and multiplied with the |
| 125 | +difference between the input and the reference input. |
| 126 | +""" |
| 127 | +struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod |
| 128 | + analyzer::A |
| 129 | + n::Int |
| 130 | +end |
| 131 | + |
| 132 | +function (aug::InterpolationAugmentation)( |
| 133 | + input, ns::AbstractNeuronSelector, input_ref=zero(input) |
| 134 | +) |
| 135 | + size(input) != size(input_ref) && |
| 136 | + throw(ArgumentError("Input reference size doesn't match input size.")) |
| 137 | + |
| 138 | + # Regular forward pass of model |
| 139 | + output = aug.analyzer.model(input) |
| 140 | + output_indices = ns(output) |
| 141 | + |
| 142 | + # Call regular analyzer on augmented batch |
| 143 | + augmented_input = interpolate_batch(input, input_ref, aug.n) |
| 144 | + augmented_indices = augment_indices(output_indices, aug.n) |
| 145 | + augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices)) |
| 146 | + |
| 147 | + # Average gradients and compute explanation |
| 148 | + expl = (input - input_ref) .* reduce_augmentation(augmented_expl.attribution, aug.n) |
| 149 | + |
| 150 | + return Explanation(expl, output, output_indices, augmented_expl.analyzer, Nothing) |
| 151 | +end |
| 152 | + |
| 153 | +""" |
| 154 | + interpolate_batch(x, x0, nsamples) |
| 155 | +
|
| 156 | +Augment batch along batch dimension using linear interpolation between input `x` and a reference input `x0`. |
| 157 | +
|
| 158 | +## Example |
| 159 | +```julia-repl |
| 160 | +julia> x = Float16.(reshape(1:4, 2, 2)) |
| 161 | +2×2 Matrix{Float16}: |
| 162 | + 1.0 3.0 |
| 163 | + 2.0 4.0 |
| 164 | +
|
| 165 | +julia> x0 = zero(x) |
| 166 | +2×2 Matrix{Float16}: |
| 167 | + 0.0 0.0 |
| 168 | + 0.0 0.0 |
| 169 | +
|
| 170 | +julia> interpolate_batch(x, x0, 5) |
| 171 | +2×10 Matrix{Float16}: |
| 172 | + 0.0 0.25 0.5 0.75 1.0 0.0 0.75 1.5 2.25 3.0 |
| 173 | + 0.0 0.5 1.0 1.5 2.0 0.0 1.0 2.0 3.0 4.0 |
| 174 | +``` |
| 175 | +""" |
| 176 | +function interpolate_batch( |
| 177 | + x::AbstractArray{T,N}, x0::AbstractArray{T,N}, nsamples |
| 178 | +) where {T,N} |
| 179 | + in_size = size(x) |
| 180 | + outs = similar(x, (in_size[1:(end - 1)]..., in_size[end] * nsamples)) |
| 181 | + colons = ntuple(Returns(:), N - 1) |
| 182 | + for (i, t) in enumerate(range(zero(T), oneunit(T); length=nsamples)) |
| 183 | + outs[colons..., i:nsamples:end] .= x0 + t * (x - x0) |
| 184 | + end |
| 185 | + return outs |
| 186 | +end |
0 commit comments