Skip to content

Commit 33c24a7

Browse files
committed
Fix default noise level for NoiseAugmentation
Breaking: rename kwarg to `std`.
1 parent 61cf3f4 commit 33c24a7

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

src/input_augmentation.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,20 @@ function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
8383
end
8484

8585
"""
86-
NoiseAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG])
87-
NoiseAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
86+
NoiseAugmentation(analyzer, n)
87+
NoiseAugmentation(analyzer, n, std::Real)
88+
NoiseAugmentation(analyzer, n, distribution::Sampleable)
8889
89-
A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
90+
A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from a scalar `distribution`.
9091
This input augmentation is then averaged to return an `Explanation`.
92+
93+
Defaults to the normal distribution `Normal(0, std^2)` with `std=1`.
94+
For optimal results, $REF_SMILKOV_SMOOTHGRAD recommends setting `std` between 10% and 20% of the input range of every sample,
95+
e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
96+
97+
## Keyword arguments
98+
- `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
99+
Defaults to `GLOBAL_RNG`.
91100
"""
92101
struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
93102
AbstractXAIMethod
@@ -96,11 +105,11 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
96105
distribution::D
97106
rng::R
98107
end
99-
function NoiseAugmentation(analyzer, n, distr::Sampleable, rng=GLOBAL_RNG)
100-
return NoiseAugmentation(analyzer, n, distr::Sampleable, rng)
108+
function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG)
109+
return NoiseAugmentation(analyzer, n, distribution::Sampleable, rng)
101110
end
102-
function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
103-
return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
111+
function NoiseAugmentation(analyzer, n, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
112+
return NoiseAugmentation(analyzer, n, Normal(zero(T), std^2), rng)
104113
end
105114

106115
function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)

0 commit comments

Comments
 (0)