|
| 1 | +""" |
| 2 | + InputAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG]) |
| 3 | + InputAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG]) |
| 4 | +
|
| 5 | +A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`. |
| 6 | +This input augmentation is then averaged to return an `Explanation`. |
| 7 | +""" |
| 8 | +struct InputAugmentation{A<:AbstractXAIMethod,D<:Distribution,R<:AbstractRNG} <: |
| 9 | + AbstractXAIMethod |
| 10 | + analyzer::A |
| 11 | + n::Integer |
| 12 | + distribution::D |
| 13 | + rng::R |
| 14 | +end |
| 15 | +function InputAugmentation(analyzer, n, distr, rng=GLOBAL_RNG) |
| 16 | + return InputAugmentation(analyzer, n, distr, rng) |
| 17 | +end |
| 18 | +function InputAugmentation(analyzer, n, σ::Real=0.1f0, args...) |
| 19 | + return InputAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...) |
| 20 | +end |
| 21 | + |
| 22 | +function (aug::InputAugmentation)(input, ns::AbstractNeuronSelector) |
| 23 | + # Regular forward pass of model |
| 24 | + output = aug.analyzer.model(input) |
| 25 | + output_indices = ns(output) |
| 26 | + |
| 27 | + # Call regular analyzer on augmented batch |
| 28 | + augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng) |
| 29 | + augmented_indices = augment_indices(output_indices, aug.n) |
| 30 | + augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices)) |
| 31 | + |
| 32 | + # Average explanation |
| 33 | + return Explanation( |
| 34 | + reduce_augmentation(augmented_expl.attribution, aug.n), |
| 35 | + output, |
| 36 | + output_indices, |
| 37 | + augmented_expl.analyzer, |
| 38 | + Nothing, |
| 39 | + ) |
| 40 | +end |
| 41 | + |
| 42 | +function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T} |
| 43 | + return A + T.(rand(rng, distr, size(A))) |
| 44 | +end |
| 45 | + |
| 46 | +""" |
| 47 | + augment_batch_dim(input, n) |
| 48 | +
|
| 49 | +Repeat each sample in input batch n-times along batch dimension. |
| 50 | +This turns arrays of size `(..., B)` into arrays of size `(..., B*n)`. |
| 51 | +
|
| 52 | +## Example |
| 53 | +```julia-repl |
| 54 | +julia> A = [1 2; 3 4] |
| 55 | +2×2 Matrix{Int64}: |
| 56 | + 1 2 |
| 57 | + 3 4 |
| 58 | +
|
| 59 | +julia> augment_batch_dim(A, 3) |
| 60 | +2×6 Matrix{Int64}: |
| 61 | + 1 1 1 2 2 2 |
| 62 | + 3 3 3 4 4 4 |
| 63 | +``` |
| 64 | +""" |
| 65 | +function augment_batch_dim(input::AbstractArray{T,N}, n) where {T,N} |
| 66 | + return repeat(input; inner=(ntuple(_ -> 1, Val(N - 1))..., n)) |
| 67 | +end |
| 68 | + |
| 69 | +""" |
| 70 | + reduce_augmentation(augmented_input, n) |
| 71 | +
|
| 72 | +Reduce augmented input batch by averaging the explanation for each augmented sample. |
| 73 | +""" |
| 74 | +function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFloat,N} |
| 75 | + return cat( |
| 76 | + ( |
| 77 | + Iterators.map(1:n:size(input, N)) do i |
| 78 | + augmentation_range = ntuple(_ -> :, Val(N - 1))..., i:(i + n - 1) |
| 79 | + sum(view(input, augmentation_range...); dims=N) / n |
| 80 | + end |
| 81 | + )...; dims=N |
| 82 | + )::Array{T,N} |
| 83 | +end |
| 84 | +""" |
| 85 | + augment_indices(indices, n) |
| 86 | +
|
| 87 | +Strip batch indices and return inidices for batch augmented by n samples. |
| 88 | +
|
| 89 | +## Example |
| 90 | +```julia-repl |
| 91 | +julia> inds = [CartesianIndex(5,1), CartesianIndex(3,2)] |
| 92 | +2-element Vector{CartesianIndex{2}}: |
| 93 | + CartesianIndex(5, 1) |
| 94 | + CartesianIndex(3, 2) |
| 95 | +
|
| 96 | +julia> augment_indices(inds, 3) |
| 97 | +6-element Vector{CartesianIndex{2}}: |
| 98 | + CartesianIndex(5, 1) |
| 99 | + CartesianIndex(5, 2) |
| 100 | + CartesianIndex(5, 3) |
| 101 | + CartesianIndex(3, 4) |
| 102 | + CartesianIndex(3, 5) |
| 103 | + CartesianIndex(3, 6) |
| 104 | +``` |
| 105 | +""" |
| 106 | +function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N} |
| 107 | + indices_wo_batch = [i.I[1:(end - 1)] for i in inds] |
| 108 | + return map(enumerate(repeat(indices_wo_batch; inner=n))) do (i, idx) |
| 109 | + CartesianIndex{N}(idx..., i) |
| 110 | + end |
| 111 | +end |
0 commit comments