
Commit 840eee3

Add compatibility checks for LRP rule & layer combinations (#75)
Changes to functionality:
* Add `check_compat` mechanism
* Modify all layers that have weights and biases
* Allow `ZBoxRule` on all layers, but throw layer compatibility error
* Remove named LRP constructors

Changes to tests, docs and benchmarks:
* Update rule tests to check for compat errors
* Remove `TestWrapper` tests and benchmarks
* Remove refs for rule & layer combinations deprecated by `check_compat`
* Update `LRPCustom` preset in docs and VGG tests
1 parent afcf4ee commit 840eee3

36 files changed (+179, -200 lines)
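As a quick orientation (a hedged sketch, not part of the commit itself): `check_compat(rule, layer)` returns `nothing` for an allowed rule/layer combination and throws an `ArgumentError` otherwise, so incompatible combinations now fail with a descriptive message instead of a cryptic AD error. The calls below only use names that appear in this diff; since the export status of `check_compat` is not shown here, it is qualified with the module name.

```julia
using Flux, ExplainableAI

# ZeroRule places no requirements on the layer type:
ExplainableAI.check_compat(ZeroRule(), Flux.flatten)    # -> nothing

# GammaRule requires weight and bias parameters:
ExplainableAI.check_compat(GammaRule(), Dense(10, 5))   # -> nothing
ExplainableAI.check_compat(GammaRule(), Flux.flatten)   # throws ArgumentError
```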

benchmark/benchmarks.jl

Lines changed: 1 addition & 14 deletions

@@ -25,7 +25,7 @@ end
 algs = Dict(
     "Gradient" => Gradient,
     "InputTimesGradient" => InputTimesGradient,
-    "LRPZero" => LRPZero,
+    "LRPZero" => LRP,
     "LRPCustom" => LRPCustom, #modifies weights
     "SmoothGrad" => model -> SmoothGrad(model, 10),
     "IntegratedGradients" => model -> IntegratedGradients(model, 10),
@@ -46,17 +46,6 @@ for (name, alg) in algs
     SUITE["VGG"][name]["analyze"] = @benchmarkable analyze($(img), $(analyzer))
 end
 
-# Rules benchmarks – use wrapper to trigger AD fallback
-struct TestWrapper{T}
-    layer::T
-end
-(w::TestWrapper)(x) = w.layer(x)
-modify_layer!(rule::R, w::TestWrapper) where {R} = modify_layer!(rule, w.layer)
-get_layer_resetter(rule::R, w::TestWrapper) where {R} = get_layer_resetter(rule, w.layer)
-get_layer_resetter(::ZeroRule, w::TestWrapper) = Returns(nothing)
-get_layer_resetter(::EpsilonRule, w::TestWrapper) = Returns(nothing)
-lrp!(Rₖ, rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp!(Rₖ, rule, w.layer, aₖ, Rₖ₊₁)
-
 # generate input for conv layers
 insize = (64, 64, 3, 1)
 in_dense = 500
@@ -67,8 +56,6 @@ layers = Dict(
     "MaxPool" => (MaxPool((3, 3); pad=0), aₖ),
     "Conv" => (Conv((3, 3), 3 => 2), aₖ),
     "Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
-    "WrappedDense" =>
-        (TestWrapper(Dense(in_dense, out_dense, relu)), randn(T, in_dense, 1)),
 )
 rules = Dict(
     "ZeroRule" => ZeroRule(),

docs/literate/advanced_lrp.jl

Lines changed: 24 additions & 14 deletions

@@ -24,13 +24,13 @@ input = reshape(x, 28, 28, 1, :);
 # For this purpose, we create an array of rules that matches the length of the Flux chain:
 rules = [
     ZBoxRule(0.0f0, 1.0f0),
-    GammaRule(),
-    GammaRule(),
-    EpsilonRule(),
     EpsilonRule(),
+    GammaRule(),
     EpsilonRule(),
     ZeroRule(),
     ZeroRule(),
+    ZeroRule(),
+    ZeroRule(),
 ]
 
 analyzer = LRP(model, rules)
@@ -60,18 +60,27 @@ function modify_param!(::MyGammaRule, param)
 end
 
 # We can directly use this rule to make an analyzer!
-analyzer = LRP(model, MyGammaRule())
+rules = [
+    ZBoxRule(0.0f0, 1.0f0),
+    EpsilonRule(),
+    MyGammaRule(),
+    EpsilonRule(),
+    ZeroRule(),
+    ZeroRule(),
+    ZeroRule(),
+    ZeroRule(),
+]
+analyzer = LRP(model, rules)
 heatmap(input, analyzer)
 
-# We just implemented our own version of the ``γ``-rule in 4 lines of code!
-# The outputs match perfectly:
-analyzer = LRP(model, GammaRule())
-heatmap(input, analyzer)
+# We just implemented our own version of the ``γ``-rule in 4 lines of code.
+# The heatmap perfectly matches the previous one!
 
 # If the layer doesn't use weights `layer.weight` and biases `layer.bias`,
 # ExplainableAI provides a lower-level variant of [`modify_param!`](@ref)
 # called [`modify_layer!`](@ref). This function is expected to take a layer
 # and return a new, modified layer.
+# To add compatibility checks between rule and layer types, extend [`check_compat`](@ref).
 
 #md # !!! warning "Using modify_layer!"
 #md #
@@ -98,7 +107,7 @@ mylayer([1, 2, 3])
 # Let's append this layer to our model:
 model = Chain(model..., MyDoublingLayer())
 
-# Creating an LRP analyzer, e.g. `LRPZero(model)`, will throw an `ArgumentError`
+# Creating an LRP analyzer, e.g. `LRP(model)`, will throw an `ArgumentError`
 # and print a summary of the model check in the REPL:
 # ```julia-repl
 # ┌───┬───────────────────────┬─────────────────┬────────────┬────────────────┐
@@ -144,7 +153,7 @@ model = Chain(model..., MyDoublingLayer())
 LRP_CONFIG.supports_layer(::MyDoublingLayer) = true
 
 # Now we can create and run an analyzer without getting an error:
-analyzer = LRPZero(model)
+analyzer = LRP(model)
 heatmap(input, analyzer)
 
 #md # !!! note "Registering functions"
@@ -163,7 +172,7 @@ model = Chain(Flux.flatten, Dense(784, 100, myrelu), Dense(100, 10))
 # Once again, creating an LRP analyzer for this model will throw an `ArgumentError`
 # and display the following model check summary:
 # ```julia-repl
-# julia> analyzer = LRPZero(model3)
+# julia> analyzer = LRP(model3)
 # ┌───┬─────────────────────────┬─────────────────┬────────────┬────────────────┐
 # │ │ Layer │ Layer supported │ Activation │ Act. supported │
 # ├───┼─────────────────────────┼─────────────────┼────────────┼────────────────┤
@@ -187,7 +196,7 @@ model = Chain(Flux.flatten, Dense(784, 100, myrelu), Dense(100, 10))
 LRP_CONFIG.supports_activation(::typeof(myrelu)) = true
 
 # now the analyzer can be created without error:
-analyzer = LRPZero(model)
+analyzer = LRP(model)
 
 # ## How it works internally
 # Internally, ExplainableAI dispatches to low level functions
@@ -248,6 +257,7 @@ analyzer = LRPZero(model)
 # compute ``c`` from the previous equation as a VJP, pulling back ``s_{k}=R_{k}/z_{k}``:
 # ```julia
 # function lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)
+#     check_compat(rule, layer)
 #     reset! = get_layer_resetter(layer)
 #     modify_layer!(rule, layer)
 #     ãₖ₊₁, pullback = Zygote.pullback(layer, modify_input(rule, aₖ))
@@ -256,8 +266,8 @@ analyzer = LRPZero(model)
 # end
 # ```
 #
-# You can see how `modify_layer!`, `modify_input` and `modify_denominator` dispatch on the
-# rule and layer type. This is how we implemented our own `MyGammaRule`.
+# You can see how `check_compat`, `modify_layer!`, `modify_input` and `modify_denominator`
+# dispatch on the rule and layer type. This is how we implemented our own `MyGammaRule`.
 # Unknown layers that are registered in the `LRP_CONFIG` use this exact function.
 
 # ### Specialized implementations
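The added documentation line above introduces `check_compat` as a third extension hook next to `modify_param!` and `modify_layer!`. As a hedged sketch of what that looks like for a custom rule (the `MyGammaRule` body is rebuilt here from the surrounding docs so the snippet is self-contained; the `0.25f0` coefficient is illustrative):

```julia
using Flux, ExplainableAI
import ExplainableAI: modify_param!, check_compat, require_weight_and_bias

# A custom γ-style rule, as in the literate example:
struct MyGammaRule <: AbstractLRPRule end
function modify_param!(::MyGammaRule, param)
    param .+= 0.25f0 * relu.(param)
    return nothing
end

# New in this commit: declare which layers the rule accepts.
check_compat(rule::MyGammaRule, layer) = require_weight_and_bias(rule, layer)
```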

docs/src/api.md

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ modify_input
 modify_denominator
 modify_param!
 modify_layer!
+check_compat
 LRP_CONFIG.supports_layer
 LRP_CONFIG.supports_activation
 ```

src/ExplainableAI.jl

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ export AbstractXAIMethod
 export Gradient, InputTimesGradient
 export NoiseAugmentation, SmoothGrad
 export InterpolationAugmentation, IntegratedGradients
-export LRP, LRPZero, LRPEpsilon, LRPGamma
+export LRP
 
 # LRP rules
 export AbstractLRPRule

src/flux.jl

Lines changed: 8 additions & 0 deletions

@@ -65,3 +65,11 @@ function strip_softmax(l::Conv)
 end
 
 has_weight_and_bias(layer) = hasproperty(layer, :weight) && hasproperty(layer, :bias)
+function require_weight_and_bias(rule, layer)
+    !has_weight_and_bias(layer) && throw(
+        ArgumentError(
+            "$rule requires linear layer with weight and bias parameters, got $layer."
+        ),
+    )
+    return nothing
+end
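A hedged usage sketch of the helper added here (layer choices are illustrative; `has_weight_and_bias` and `require_weight_and_bias` are internal, so they are imported explicitly):

```julia
using Flux
using ExplainableAI: GammaRule, has_weight_and_bias, require_weight_and_bias

has_weight_and_bias(Dense(3, 2))      # true:  Dense carries `weight` and `bias`
has_weight_and_bias(MaxPool((2, 2)))  # false: pooling layers have no parameters

# The new helper turns the negative case into a descriptive ArgumentError:
require_weight_and_bias(GammaRule(), MaxPool((2, 2)))   # throws ArgumentError
```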

src/lrp.jl

Lines changed: 1 addition & 4 deletions

@@ -41,11 +41,8 @@ function LRP(model::Chain, r::AbstractLRPRule; kwargs...)
     rules = repeat([r], length(model.layers))
     return LRP(model, rules; kwargs...)
 end
-# Additional constructors for convenience:
+# Additional constructors for convenience: use ZeroRule everywhere
 LRP(model::Chain; kwargs...) = LRP(model, ZeroRule(); kwargs...)
-LRPZero(model::Chain; kwargs...) = LRP(model, ZeroRule(); kwargs...)
-LRPEpsilon(model::Chain; kwargs...) = LRP(model, EpsilonRule(); kwargs...)
-LRPGamma(model::Chain; kwargs...) = LRP(model, GammaRule(); kwargs...)
 
 # The call to the LRP analyzer.
 function (analyzer::LRP)(
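For downstream code that used the removed convenience constructors, equivalent calls after this commit follow directly from the deleted one-line definitions above (the model below is only for illustration):

```julia
using Flux, ExplainableAI

model = Chain(Dense(10, 5, relu), Dense(5, 2))   # illustrative model

analyzer_zero    = LRP(model)                    # was LRPZero(model)
analyzer_epsilon = LRP(model, EpsilonRule())     # was LRPEpsilon(model)
analyzer_gamma   = LRP(model, GammaRule())       # was LRPGamma(model)
```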

src/lrp_rules.jl

Lines changed: 61 additions & 56 deletions

@@ -1,12 +1,10 @@
 # https://adrhill.github.io/ExplainableAI.jl/stable/generated/advanced_lrp/#How-it-works-internally
 abstract type AbstractLRPRule end
 
-# TODO: support all linear layers that use properties `weight` and `bias`
-const WeightBiasLayers = (Dense, Conv)
-
 # Generic LRP rule. Since it uses autodiff, it is used as a fallback for layer types
 # without custom implementations.
 function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
+    check_compat(rule, layer)
     reset! = get_layer_resetter(rule, layer)
     modify_layer!(rule, layer)
     ãₖ₊₁, pullback = Zygote.pullback(layer, modify_input(rule, aₖ))
@@ -18,6 +16,7 @@ end
 # To implement new rules, define the following custom functions:
 # * `modify_input(rule, input)`
 # * `modify_denominator(rule, d)`
+# * `check_compat(rule, layer)`
 # * `modify_param!(rule, param)` or `modify_layer!(rule, layer)`,
 #   the latter overriding the former
 #
@@ -36,6 +35,17 @@ Modify denominator ``z`` for numerical stability on the forward pass.
 """
 @inline modify_denominator(rule, d) = stabilize_denom(d, 1.0f-9) # general fallback
 
+"""
+    check_compat(rule, layer)
+
+Check compatibility of a LRP-Rule with layer type.
+
+## Note
+When implementing a custom `check_compat` function, return `nothing` if checks passed,
+otherwise throw an `ArgumentError`.
+"""
+@inline check_compat(rule, layer) = require_weight_and_bias(rule, layer)
+
 """
     modify_layer!(rule, layer)
 
@@ -45,15 +55,12 @@ propagation.
 ## Note
 When implementing a custom `modify_layer!` function, `modify_param!` will not be called.
 """
-modify_layer!(rule, layer) = nothing
-for L in WeightBiasLayers
-    @eval function modify_layer!(rule::R, layer::$L) where {R}
-        if has_weight_and_bias(layer)
-            modify_param!(rule, layer.weight)
-            modify_bias!(rule, layer.bias)
-        end
-        return nothing
+function modify_layer!(rule::R, layer::L) where {R,L}
+    if has_weight_and_bias(layer)
+        modify_param!(rule, layer.weight)
+        modify_bias!(rule, layer.bias)
     end
+    return nothing
 end
 
 """
@@ -97,6 +104,22 @@ end
 Constructor for LRP-0 rule. Commonly used on upper layers.
 """
 struct ZeroRule <: AbstractLRPRule end
+@inline check_compat(::ZeroRule, layer) = nothing
+
+"""
+    EpsilonRule([ϵ=1.0f-6])
+
+Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.
+
+Arguments:
+- `ϵ`: Optional stabilization parameter, defaults to `1f-6`.
+"""
+struct EpsilonRule{T} <: AbstractLRPRule
+    ϵ::T
+    EpsilonRule(ϵ=1.0f-6) = new{Float32}(ϵ)
+end
+modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
+@inline check_compat(::EpsilonRule, layer) = nothing
 
 """
     GammaRule([γ=0.25])
@@ -115,20 +138,7 @@ function modify_param!(r::GammaRule, param::AbstractArray{T}) where {T}
     param .+= γ * relu.(param)
     return nothing
 end
-
-"""
-    EpsilonRule([ϵ=1.0f-6])
-
-Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.
-
-Arguments:
-- `ϵ`: Optional stabilization parameter, defaults to `1f-6`.
-"""
-struct EpsilonRule{T} <: AbstractLRPRule
-    ϵ::T
-    EpsilonRule(ϵ=1.0f-6) = new{Float32}(ϵ)
-end
-modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
+@inline check_compat(rule::GammaRule, layer) = require_weight_and_bias(rule, layer)
 
 """
     ZBoxRule(low, high)
@@ -146,45 +156,44 @@ struct ZBoxRule{T} <: AbstractLRPRule
 end
 
 # The ZBoxRule requires its own implementation of relevance propagation.
-for L in WeightBiasLayers
-    function lrp!(Rₖ, rule::ZBoxRule, layer, aₖ, Rₖ₊₁)
-        T = eltype(aₖ)
-        l = zbox_input_augmentation(T, rule.low, size(aₖ))
-        h = zbox_input_augmentation(T, rule.high, size(aₖ))
-        reset! = get_layer_resetter(rule, layer)
+function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
+    require_weight_and_bias(rule, layer)
+    reset! = get_layer_resetter(rule, layer)
 
-        # Compute pullback for W, b
-        aₖ₊₁, pullback = Zygote.pullback(layer, aₖ)
+    l = zbox_input(aₖ, rule.low)
+    h = zbox_input(aₖ, rule.high)
 
-        # Compute pullback for W⁺, b⁺
-        modify_layer!(Val{:mask_positive}, layer)
-        aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, l)
-        reset!()
+    # Compute pullback for W, b
+    aₖ₊₁, pullback = Zygote.pullback(layer, aₖ)
 
-        # Compute pullback for W⁻, b⁻
-        modify_layer!(Val{:mask_negative}, layer)
-        aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, h)
-        reset!()
+    # Compute pullback for W⁺, b⁺
+    modify_layer!(Val{:mask_positive}, layer)
+    aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, l)
+    reset!()
 
-        y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ - aₖ₊₁⁺ - aₖ₊₁⁻)
-        Rₖ .= aₖ .* only(pullback(y)) - l .* only(pullback⁺(y)) - h .* only(pullback⁻(y))
-        return nothing
-    end
+    # Compute pullback for W⁻, b⁻
+    modify_layer!(Val{:mask_negative}, layer)
+    aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, h)
+    reset!()
+
+    y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ - aₖ₊₁⁺ - aₖ₊₁⁻)
+    Rₖ .= aₖ .* only(pullback(y)) - l .* only(pullback⁺(y)) - h .* only(pullback⁻(y))
+    return nothing
 end
 
-const ZBOX_BOUNDS_MISMATCH = "ZBoxRule bounds should either be scalar or match input size."
-function zbox_input_augmentation(T, A::AbstractArray, in_size)
-    size(A) != in_size && throw(ArgumentError(ZBOX_BOUNDS_MISMATCH))
+zbox_input(in::AbstractArray{T}, c::Real) where {T} = fill(convert(T, c), size(in))
+function zbox_input(in::AbstractArray{T}, A::AbstractArray) where {T}
+    @assert size(A) == size(in)
     return convert.(T, A)
 end
-zbox_input_augmentation(T, c::Real, in_size) = fill(convert(T, c), in_size)
 
-# Other special cases that are dispatched on layer type:
-const LRPRules = (ZeroRule, EpsilonRule, GammaRule, ZBoxRule)
-for R in LRPRules
+# Special cases for rules that don't modify params for extra performance:
+for R in (ZeroRule, EpsilonRule)
+    @eval get_layer_resetter(::$R, l) = Returns(nothing)
     @eval lrp!(Rₖ, ::$R, ::DropoutLayer, aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
     @eval lrp!(Rₖ, ::$R, ::ReshapingLayer, aₖ, Rₖ₊₁) = (Rₖ .= reshape(Rₖ₊₁, size(aₖ)))
 end
+
 # Fast implementation for Dense layer using Tullio.jl's einsum notation:
 for R in (ZeroRule, EpsilonRule, GammaRule)
     @eval function lrp!(Rₖ, rule::$R, layer::Dense, aₖ, Rₖ₊₁)
@@ -196,7 +205,3 @@ for R in (ZeroRule, EpsilonRule, GammaRule)
         return nothing
     end
 end
-
-# Rules that don't modify params can optionally be added here for extra performance
-get_layer_resetter(::ZeroRule, l) = Returns(nothing)
-get_layer_resetter(::EpsilonRule, l) = Returns(nothing)
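A hedged sketch of the two behavioral changes visible in this file: the new `zbox_input` helper accepts scalar or array bounds, and `ZBoxRule` now dispatches on any layer type but rejects parameter-free layers at propagation time (both functions are internal, hence the explicit imports):

```julia
using Flux
using ExplainableAI: ZBoxRule, zbox_input, lrp!

aₖ = rand(Float32, 4, 1)
zbox_input(aₖ, 0.0f0)                 # array of size(aₖ), filled with the scalar bound
zbox_input(aₖ, zeros(Float32, 4, 1))  # array bounds are converted to eltype(aₖ)

# ZBoxRule on a layer without weight and bias now throws the compatibility error:
Rₖ, Rₖ₊₁ = similar(aₖ), copy(aₖ)
lrp!(Rₖ, ZBoxRule(0.0f0, 1.0f0), Flux.flatten, aₖ, Rₖ₊₁)   # throws ArgumentError
```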
