Commit af5a2d8

Add Integrated Gradients analyzer (#65)
* Reorganize input_augmentation.jl
* Rename `InputAugmentation` to `NoiseAugmentation`
* Add `InterpolationAugmentation` and tests
* Add `IntegratedGradients` and tests
* Update docs and readme
* Add benchmark for `IntegratedGradients`
* Update deprecated MLDatasets calls to API of `v0.7`
* Allow any type of `Sampleable` in `NoiseAugmentation`
1 parent f1b89ab commit af5a2d8

14 files changed (+164 −63 lines)

README.md
Lines changed: 2 additions & 2 deletions

@@ -33,7 +33,7 @@ using BSON: @load
 model = strip_softmax(model)
 
 # Load input
-x, _ = MNIST.testdata(Float32, 10)
+x, _ = MNIST(Float32, :test)[10]
 input = reshape(x, 28, 28, 1, :) # reshape to WHCN format
 
 # Run XAI method
@@ -55,6 +55,7 @@ Currently, the following analyzers are implemented:
 ├── Gradient
 ├── InputTimesGradient
 ├── SmoothGrad
+├── IntegratedGradients
 └── LRP
     ├── LRPZero
     ├── LRPEpsilon
@@ -66,7 +67,6 @@ Individual LRP rules like `ZeroRule`, `EpsilonRule`, `GammaRule` and `ZBoxRule`
 
 ## Roadmap
 In the future, we would like to include:
-- [Integrated Gradients](https://arxiv.org/abs/1703.01365)
 - [PatternNet](https://arxiv.org/abs/1705.05598)
 - [DeepLift](https://arxiv.org/abs/1704.02685)
 - [LIME](https://arxiv.org/abs/1602.04938)

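For orientation, here is a minimal sketch of how the newly added analyzer slots into the README's MNIST example above. The call sequence is an assumption based on this diff (it mirrors the existing `analyze`/`heatmap` usage in the docs), not text from the commit itself; `model` is the stripped-softmax Flux model loaded earlier in the README.

```julia
using ExplainableAI
using MLDatasets

# Load a single MNIST test image and reshape it to WHCN format, as in the README above
x, _ = MNIST(Float32, :test)[10]
input = reshape(x, 28, 28, 1, :)

# The new analyzer; 50 interpolation steps is the default per the added docstring
analyzer = IntegratedGradients(model, 50)
expl = analyze(input, analyzer)   # returns an Explanation
heatmap(input, analyzer)          # or visualize directly as a heatmap
```
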
benchmark/benchmarks.jl
Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ algs = Dict(
     "LRPZero" => LRPZero,
     "LRPCustom" => LRPCustom, #modifies weights
     "SmoothGrad" => model -> SmoothGrad(model, 10),
+    "IntegratedGradients" => model -> IntegratedGradients(model, 10),
 )
 
 # Define benchmark

docs/Project.toml
Lines changed: 2 additions & 0 deletions

@@ -5,5 +5,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 ExplainableAI = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
+ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
+ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"

docs/literate/advanced_lrp.jl
Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@ using BSON
 model = BSON.load("../model.bson", @__MODULE__)[:model]
 
 index = 10
-x, y = MNIST.testdata(Float32, index)
+x, _ = MNIST(Float32, :test)[10]
 input = reshape(x, 28, 28, 1, :);
 
 # ## Custom LRP composites
@@ -200,7 +200,7 @@ analyzer = LRPZero(model)
 # They in-place modify a pre-allocated array of the input relevance `Rₖ`
 # based on the input activation `aₖ` and output relevance `Rₖ₊₁`.
 #
-# Calling `analyze` then applies a foward-pass of the model, keeping track of
+# Calling `analyze` then applies a forward-pass of the model, keeping track of
 # the activations `aₖ` for each layer `k`.
 # The relevance `Rₖ₊₁` is then set to the output neuron activation and the rules are applied
 # in a backward-pass over the model layers and previous activations.

docs/literate/example.jl
Lines changed: 6 additions & 3 deletions

@@ -29,11 +29,13 @@ model = BSON.load("../model.bson", @__MODULE__)[:model]
 # We use MLDatasets to load a single image from the MNIST dataset:
 using MLDatasets
 using ImageCore
+using ImageIO
+using ImageShow
 
 index = 10
-x, y = MNIST.testdata(Float32, index)
+x, _ = MNIST(Float32, :test)[10]
 
-MNIST.convert2image(x)
+convert2image(MNIST, x)
 
 # By convention in Flux.jl, this input needs to be resized to WHCN format by adding a color channel and batch dimensions.
 input = reshape(x, 28, 28, 1, :);
@@ -82,7 +84,7 @@ heatmap(input, analyzer, 5)
 # ## Input batches
 # ExplainableAI also supports input batches:
 batchsize = 100
-xs, _ = MNIST.testdata(Float32, 1:batchsize)
+xs, _ = MNIST(Float32, :test)[1:batchsize]
 batch = reshape(xs, 28, 28, 1, :) # reshape to WHCN format
 expl_batch = analyze(batch, analyzer);
 
@@ -106,6 +108,7 @@ mosaic(heatmap(batch, analyzer, 1); nrow=10)
 # ├── Gradient
 # ├── InputTimesGradient
 # ├── SmoothGrad
+# ├── IntegratedGradients
 # └── LRP
 #     ├── LRPZero
 #     ├── LRPEpsilon

docs/src/api.md
Lines changed: 4 additions & 2 deletions

@@ -11,11 +11,13 @@ LRP
 Gradient
 InputTimesGradient
 SmoothGrad
+IntegratedGradients
 ```
 
-`SmoothGrad` is a special case of `InputAugmentation`, which can be applied as a wrapper to any analyzer:
+`SmoothGrad` and `IntegratedGradients` are special cases of the input augmentation wrappers `NoiseAugmentation` and `InterpolationAugmentation`, which can be applied as a wrapper to any analyzer:
 ```@docs
-InputAugmentation
+NoiseAugmentation
+InterpolationAugmentation
 ```
 
 # LRP

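As a concrete illustration of the documentation change above, a hedged sketch of wrapping other analyzers in the two augmentation wrappers. Constructor arguments follow the docstrings added in this commit; which analyzers benefit from augmentation in practice is left to the docs, and the variable names below are illustrative only.

```julia
# NoiseAugmentation averages explanations over n noisy copies of the input
# (SmoothGrad is exactly NoiseAugmentation around a Gradient analyzer).
smoothed_lrp = NoiseAugmentation(LRPZero(model), 50)

# InterpolationAugmentation averages gradients along a straight path from a
# reference input to the input (IntegratedGradients wraps a Gradient analyzer).
ig_style = InterpolationAugmentation(Gradient(model), 50)

expl = analyze(input, smoothed_lrp)
```
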
src/ExplainableAI.jl
Lines changed: 3 additions & 2 deletions

@@ -2,7 +2,7 @@ module ExplainableAI
 
 using Base.Iterators
 using LinearAlgebra
-using Distributions
+using Distributions: Distribution, Sampleable, Normal
 using Random: AbstractRNG, GLOBAL_RNG
 using Flux
 using Zygote
@@ -34,7 +34,8 @@ export analyze
 # Analyzers
 export AbstractXAIMethod
 export Gradient, InputTimesGradient
-export InputAugmentation, SmoothGrad
+export NoiseAugmentation, SmoothGrad
+export InterpolationAugmentation, IntegratedGradients
 export LRP, LRPZero, LRPEpsilon, LRPGamma
 
 # LRP rules

src/gradient.jl
Lines changed: 12 additions & 1 deletion

@@ -61,4 +61,15 @@ in a neighborhood of the input, typically by adding Gaussian noise with mean 0.
 # References
 [1] Smilkov et al., SmoothGrad: removing noise by adding noise
 """
-SmoothGrad(model, n=50, args...) = InputAugmentation(Gradient(model), n, args...)
+SmoothGrad(model, n=50, args...) = NoiseAugmentation(Gradient(model), n, args...)
+
+"""
+    IntegratedGradients(analyzer, [n=50])
+    IntegratedGradients(analyzer, [n=50])
+
+Analyze model by using the Integrated Gradients method.
+
+# References
+[1] Sundararajan et al., Axiomatic Attribution for Deep Networks
+"""
+IntegratedGradients(model, n=50) = InterpolationAugmentation(Gradient(model), n)

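For background, the one-liner `IntegratedGradients(model, n) = InterpolationAugmentation(Gradient(model), n)` is a Riemann-sum approximation of the attribution defined by Sundararajan et al. [1]. A sketch of the quantity being approximated (notation introduced here, not in the commit): x′ is the reference input, here `zero(input)`, F is the model output for the selected neuron, and the t_k are the n equally spaced interpolation points used by `interpolate_batch`.

```math
\mathrm{IG}_i(x)
  = (x_i - x'_i) \int_0^1
      \frac{\partial F\bigl(x' + \alpha\,(x - x')\bigr)}{\partial x_i}\, \mathrm{d}\alpha
  \;\approx\;
    (x_i - x'_i)\, \frac{1}{n} \sum_{k=1}^{n}
      \frac{\partial F\bigl(x' + t_k\,(x - x')\bigr)}{\partial x_i},
\qquad t_k \text{ equally spaced in } [0, 1].
```
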
src/input_augmentation.jl
Lines changed: 117 additions & 48 deletions

@@ -1,48 +1,3 @@
-"""
-    InputAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG])
-    InputAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
-
-A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
-This input augmentation is then averaged to return an `Explanation`.
-"""
-struct InputAugmentation{A<:AbstractXAIMethod,D<:Distribution,R<:AbstractRNG} <:
-       AbstractXAIMethod
-    analyzer::A
-    n::Integer
-    distribution::D
-    rng::R
-end
-function InputAugmentation(analyzer, n, distr, rng=GLOBAL_RNG)
-    return InputAugmentation(analyzer, n, distr, rng)
-end
-function InputAugmentation(analyzer, n, σ::Real=0.1f0, args...)
-    return InputAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
-end
-
-function (aug::InputAugmentation)(input, ns::AbstractNeuronSelector)
-    # Regular forward pass of model
-    output = aug.analyzer.model(input)
-    output_indices = ns(output)
-
-    # Call regular analyzer on augmented batch
-    augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng)
-    augmented_indices = augment_indices(output_indices, aug.n)
-    augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices))
-
-    # Average explanation
-    return Explanation(
-        reduce_augmentation(augmented_expl.attribution, aug.n),
-        output,
-        output_indices,
-        augmented_expl.analyzer,
-        Nothing,
-    )
-end
-
-function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T}
-    return A + T.(rand(rng, distr, size(A)))
-end
-
 """
     augment_batch_dim(input, n)
 
@@ -80,13 +35,13 @@ function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFloat,N}
     out = similar(input, eltype(input), out_size)
 
     axs = axes(input, N)
-    inds_before_N = ntuple(Returns(:), N - 1)
+    colons = ntuple(Returns(:), N - 1)
     for (i, ax) in enumerate(first(axs):n:last(axs))
-        view(out, inds_before_N..., i) .=
-            sum(view(input, inds_before_N..., ax:(ax + n - 1)); dims=N) / n
+        view(out, colons..., i) .= sum(view(input, colons..., ax:(ax + n - 1)); dims=N) / n
     end
     return out
 end
+
 """
     augment_indices(indices, n)
 
@@ -115,3 +70,117 @@ function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
         CartesianIndex{N}(idx..., i)
     end
 end
+
+"""
+    NoiseAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG])
+    NoiseAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
+
+A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
+This input augmentation is then averaged to return an `Explanation`.
+"""
+struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
+       AbstractXAIMethod
+    analyzer::A
+    n::Int
+    distribution::D
+    rng::R
+end
+function NoiseAugmentation(analyzer, n, distr::Sampleable, rng=GLOBAL_RNG)
+    return NoiseAugmentation(analyzer, n, distr::Sampleable, rng)
+end
+function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
+    return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
+end
+
+function (aug::NoiseAugmentation)(input, ns::AbstractNeuronSelector)
+    # Regular forward pass of model
+    output = aug.analyzer.model(input)
+    output_indices = ns(output)
+
+    # Call regular analyzer on augmented batch
+    augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng)
+    augmented_indices = augment_indices(output_indices, aug.n)
+    augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices))
+
+    # Average explanation
+    return Explanation(
+        reduce_augmentation(augmented_expl.attribution, aug.n),
+        output,
+        output_indices,
+        augmented_expl.analyzer,
+        Nothing,
+    )
+end
+
+function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T}
+    return A + T.(rand(rng, distr, size(A)))
+end
+
+"""
+    InterpolationAugmentation(model, [n=50])
+
+A wrapper around analyzers that augments the input with `n` steps of linear interpolation
+between the input and a reference input (typically `zero(input)`).
+The gradients w.r.t. this augmented input are then averaged and multiplied with the
+difference between the input and the reference input.
+"""
+struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
+    analyzer::A
+    n::Int
+end
+
+function (aug::InterpolationAugmentation)(
+    input, ns::AbstractNeuronSelector, input_ref=zero(input)
+)
+    size(input) != size(input_ref) &&
+        throw(ArgumentError("Input reference size doesn't match input size."))
+
+    # Regular forward pass of model
+    output = aug.analyzer.model(input)
+    output_indices = ns(output)
+
+    # Call regular analyzer on augmented batch
+    augmented_input = interpolate_batch(input, input_ref, aug.n)
+    augmented_indices = augment_indices(output_indices, aug.n)
+    augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices))
+
+    # Average gradients and compute explanation
+    expl = (input - input_ref) .* reduce_augmentation(augmented_expl.attribution, aug.n)
+
+    return Explanation(expl, output, output_indices, augmented_expl.analyzer, Nothing)
+end
+
+"""
+    interpolate_batch(x, x0, nsamples)
+
+Augment batch along batch dimension using linear interpolation between input `x` and a reference input `x0`.
+
+## Example
+```julia-repl
+julia> x = Float16.(reshape(1:4, 2, 2))
+2×2 Matrix{Float16}:
+ 1.0  3.0
+ 2.0  4.0
+
+julia> x0 = zero(x)
+2×2 Matrix{Float16}:
+ 0.0  0.0
+ 0.0  0.0
+
+julia> interpolate_batch(x, x0, 5)
+2×10 Matrix{Float16}:
+ 0.0  0.25  0.5  0.75  1.0  0.0  0.75  1.5  2.25  3.0
+ 0.0  0.5   1.0  1.5   2.0  0.0  1.0   2.0  3.0   4.0
+```
+"""
+function interpolate_batch(
+    x::AbstractArray{T,N}, x0::AbstractArray{T,N}, nsamples
+) where {T,N}
+    in_size = size(x)
+    outs = similar(x, (in_size[1:(end - 1)]..., in_size[end] * nsamples))
+    colons = ntuple(Returns(:), N - 1)
+    for (i, t) in enumerate(range(zero(T), oneunit(T); length=nsamples))
+        outs[colons..., i:nsamples:end] .= x0 + t * (x - x0)
+    end
+    return outs
+end
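
To make the shared helper behaviour concrete, here is a small sketch of `reduce_augmentation`, the internal (unexported) helper both wrappers use to average the augmented explanations. The expected values follow from the loop shown above, under the assumption of a 4-D WHCN batch; this is an illustration, not part of the commit.

```julia
using ExplainableAI

# A WHCN "batch" of six scalar samples: 1, 2, 3, 4, 5, 6
x = reshape(Float32.(1:6), 1, 1, 1, 6)

# Average every block of n = 3 consecutive batch slices back down to one slice:
# mean(1, 2, 3) = 2 and mean(4, 5, 6) = 5
reduced = ExplainableAI.reduce_augmentation(x, 3)

size(reduced)  # (1, 1, 1, 2)
vec(reduced)   # Float32[2.0, 5.0]
```
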
(589 KB binary file not shown)
