
Commit 239deae

Fix bug in LRP rules (#92)
* Fix bug in `AlphaBetaRule`, `ZPlusRule`, `ZBoxRule` and `FlatRule`
* Add `preactivation` function
* Add analytic tests for `AlphaBetaRule` and `ZPlusRule`
* Update references
1 parent bc06035 commit 239deae
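
Editor's note for context (not part of the commit message): these LRP rules redistribute the relevance Rₖ₊₁ onto the inputs aₖ in proportion to their contributions to the pre-activation zₖ₊₁ = W·aₖ + b, not to the activated output σ.(zₖ₊₁). Taking the Zygote pullback through the full layer call propagates relevance through the activation as well, which is what the switch to `preactivation` below avoids. A hand-rolled sketch of one propagation step for a Dense layer (illustrative helper name, epsilon stabilizer chosen arbitrarily):

# One LRP-0/ε-style step for a Dense layer, for illustration only
function lrp_dense_step(W, b, aₖ, Rₖ₊₁; eps=1.0f-9)
    zₖ₊₁ = W * aₖ .+ b            # pre-activation, not σ.(W * aₖ .+ b)
    s = Rₖ₊₁ ./ (zₖ₊₁ .+ eps)     # relevance divided by the stabilized pre-activation
    return aₖ .* (W' * s)         # relevance Rₖ, redistributed onto the inputs
end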

32 files changed: +122 -44 lines

src/flux_utils.jl

Lines changed: 35 additions & 0 deletions
@@ -9,6 +9,12 @@ activation(l::Conv) = l.σ
 activation(l::BatchNorm) = l.λ
 activation(layer) = nothing # default for all other layer types

+function has_activation(layer)
+    hasproperty(layer, :σ) && return true
+    hasproperty(layer, :λ) && return true
+    return false
+end
+
 """
     flatten_model(c)

@@ -73,3 +79,32 @@ function require_weight_and_bias(rule, layer)
     )
     return nothing
 end
+
+# LRP requires computing so-called pre-activations `z`.
+# These correspond to calling a layer without applying its activation function.
+preactivation(layer) = x -> preactivation(layer, x)
+function preactivation(d::Dense, x::AbstractVecOrMat)
+    return d.weight * x .+ d.bias
+end
+function preactivation(d::Dense, x::AbstractArray)
+    return reshape(preactivation(d, reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
+end
+function preactivation(c::Conv, x)
+    cdims = Flux.conv_dims(c, x)
+    return Flux.conv(x, c.weight, cdims) .+ Flux.conv_reshape_bias(c)
+end
+
+function preactivation(c::ConvTranspose, x)
+    cdims = Flux.conv_transpose_dims(c, x)
+    return Flux.∇conv_data(x, c.weight, cdims) .+ Flux.conv_reshape_bias(c)
+end
+function preactivation(c::CrossCor, x)
+    cdims = Flux.crosscor_dims(c, x)
+    return Flux.crosscor(x, c.weight, cdims) .+ Flux.conv_reshape_bias(c)
+end
+function preactivation(l, x)
+    has_activation(l) &&
+        error("""Layer $l contains an activation function and therefore requires an
+              implementation of `preactivation(layer, input)`""")
+    return l(x)
+end
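
Editor's sanity check (illustration, not part of this commit): for a `Dense` layer, applying the layer's activation to the pre-activation reproduces the ordinary forward pass. Sizes and data below are arbitrary:

using Flux

W = randn(Float32, 2, 3)
b = zeros(Float32, 2)
d = Dense(W, b, relu)
x = rand(Float32, 3, 5)
z = d.weight * x .+ d.bias    # what the new preactivation(d, x) computes
@assert d(x) ≈ relu.(z)       # layer output = activation applied to the pre-activation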

src/lrp/rules.jl

Lines changed: 50 additions & 28 deletions
@@ -13,7 +13,7 @@ function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
     check_compat(rule, layer)
     reset! = get_layer_resetter(rule, layer)
     modify_layer!(rule, layer)
-    ãₖ₊₁, pullback = Zygote.pullback(layer, modify_input(rule, aₖ))
+    ãₖ₊₁, pullback = Zygote.pullback(preactivation(layer), modify_input(rule, aₖ))
     Rₖ .= aₖ .* only(pullback(Rₖ₊₁ ./ modify_denominator(rule, ãₖ₊₁)))
     reset!()
     return nothing

@@ -180,13 +180,15 @@ modify_input(::WSquareRule, input) = ones_like(input)
 """
     FlatRule()

-LRP-Flat rule. Similar to the [`WSquareRule`](@ref), but with all parameters set to one.
+LRP-Flat rule. Similar to the [`WSquareRule`](@ref), but with all weights set to one
+and all bias terms set to zero.

 # References
 - $REF_LAPUSCHKIN_CLEVER_HANS
 """
 struct FlatRule <: AbstractLRPRule end
-modify_param!(::FlatRule, p) = fill!(p, 1)
+modify_weight!(::FlatRule, w) = fill!(w, 1)
+modify_bias!(::FlatRule, b) = fill!(b, 0)
 modify_input(::FlatRule, input) = ones_like(input)

 """
@@ -232,20 +234,25 @@ function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
     h = zbox_input(aₖ, rule.high)

     # Compute pullback for W, b
-    aₖ₊₁, pullback = Zygote.pullback(layer, aₖ)
+    aₖ₊₁, pullback = Zygote.pullback(preactivation(layer), aₖ)

     # Compute pullback for W⁺, b⁺
     modify_layer!(Val(:keep_positive), layer)
-    aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, l)
+    aₖ₊₁⁺, pullback⁺ = Zygote.pullback(preactivation(layer), l)
     reset!()

     # Compute pullback for W⁻, b⁻
     modify_layer!(Val(:keep_negative), layer)
-    aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, h)
-    reset!()
+    aₖ₊₁⁻, pullback⁻ = Zygote.pullback(preactivation(layer), h)

+    # Evaluate pullbacks
     y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ - aₖ₊₁⁺ - aₖ₊₁⁻)
-    Rₖ .= aₖ .* only(pullback(y)) - l .* only(pullback⁺(y)) - h .* only(pullback⁻(y))
+    Rₖ .= -h .* only(pullback⁻(y))
+    reset!() # re-modify mutated pullback
+    Rₖ .+= aₖ .* only(pullback(y))
+    modify_layer!(Val(:keep_positive), layer) # re-modify mutated pullback
+    Rₖ .-= l .* only(pullback⁺(y))
+    reset!()
     return nothing
 end

@@ -278,7 +285,7 @@ struct AlphaBetaRule{T} <: AbstractLRPRule
         alpha < 0 && throw(ArgumentError("Parameter `alpha` must be ≥0."))
         beta < 0 && throw(ArgumentError("Parameter `beta` must be ≥0."))
         !isone(alpha - beta) && throw(ArgumentError("`alpha - beta` must be equal one."))
-        return new{Float32}(alpha, beta)
+        return new{eltype(alpha)}(alpha, beta)
     end
 end

@@ -290,24 +297,33 @@ function lrp!(Rₖ, rule::AlphaBetaRule, layer::L, aₖ, Rₖ₊₁) where {L}
     aₖ⁺ = keep_positive(aₖ)
     aₖ⁻ = keep_negative(aₖ)

+    # α: positive contributions
+    modify_layer!(Val(:keep_negative_zero_bias), layer)
+    aₖ₊₁ᵅ⁻, pullbackᵅ⁻ = Zygote.pullback(preactivation(layer), aₖ⁻)
+    reset!()
     modify_layer!(Val(:keep_positive), layer)
-    out_1, pullback_1 = Zygote.pullback(layer, aₖ⁺)
+    aₖ₊₁ᵅ⁺, pullbackᵅ⁺ = Zygote.pullback(preactivation(layer), aₖ⁺)
+    # evaluate pullbacks
+    yᵅ = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ᵅ⁺ + aₖ₊₁ᵅ⁻)
+    Rₖ .= rule.α .* aₖ⁺ .* only(pullbackᵅ⁺(yᵅ))
     reset!()
-    modify_layer!(Val(:keep_negative_zero_bias), layer)
-    out_2, pullback_2 = Zygote.pullback(layer, aₖ⁻)
+    modify_layer!(Val(:keep_negative_zero_bias), layer) # re-modify mutated pullback
+    Rₖ .+= rule.α .* aₖ⁻ .* only(pullbackᵅ⁻(yᵅ))
+    reset!()
+
+    # β: negative contributions
+    modify_layer!(Val(:keep_positive_zero_bias), layer)
+    aₖ₊₁ᵝ⁻, pullbackᵝ⁻ = Zygote.pullback(preactivation(layer), aₖ⁻)
     reset!()
     modify_layer!(Val(:keep_negative), layer)
-    out_3, pullback_3 = Zygote.pullback(layer, aₖ⁺)
+    aₖ₊₁ᵝ⁺, pullbackᵝ⁺ = Zygote.pullback(preactivation(layer), aₖ⁺)
+    # evaluate pullbacks
+    yᵝ = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ᵝ⁺ + aₖ₊₁ᵝ⁻)
+    Rₖ .-= rule.β .* aₖ⁺ .* only(pullbackᵝ⁺(yᵝ))
     reset!()
-    modify_layer!(Val(:keep_positive_zero_bias), layer)
-    out_4, pullback_4 = Zygote.pullback(layer, aₖ⁻)
+    modify_layer!(Val(:keep_positive_zero_bias), layer) # re-modify mutated pullback
+    Rₖ .-= rule.β .* aₖ⁻ .* only(pullbackᵝ⁻(yᵝ))
     reset!()
-
-    y_α = Rₖ₊₁ ./ modify_denominator(rule, out_1 + out_2)
-    y_β = Rₖ₊₁ ./ modify_denominator(rule, out_3 + out_4)
-    Rₖ .=
-        rule.α .* (aₖ⁺ .* only(pullback_1(y_α)) + aₖ⁻ .* only(pullback_2(y_α))) .-
-        rule.β .* (aₖ⁺ .* only(pullback_3(y_β)) + aₖ⁻ .* only(pullback_4(y_β)))
     return nothing
 end

@@ -331,15 +347,20 @@ function lrp!(Rₖ, rule::ZPlusRule, layer::L, aₖ, Rₖ₊₁) where {L}
331347
aₖ⁺ = keep_positive(aₖ)
332348
aₖ⁻ = keep_negative(aₖ)
333349

350+
# Linearize around positive & negative activations (aₖ⁺, aₖ⁻)
334351
modify_layer!(Val(:keep_positive), layer)
335-
out_1, pullback_1 = Zygote.pullback(layer, aₖ⁺)
352+
aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, aₖ⁺)
336353
reset!()
337354
modify_layer!(Val(:keep_negative_zero_bias), layer)
338-
out_2, pullback_2 = Zygote.pullback(layer, aₖ⁻)
339-
reset!()
355+
aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, aₖ⁻)
340356

341-
y_α = Rₖ₊₁ ./ modify_denominator(rule, out_1 + out_2)
342-
Rₖ .= aₖ⁺ .* only(pullback_1(y_α)) + aₖ⁻ .* only(pullback_2(y_α))
357+
# Evaluate pullbacks
358+
y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁⁺ + aₖ₊₁⁻)
359+
Rₖ .= aₖ⁻ .* only(pullback⁻(y))
360+
reset!()
361+
modify_layer!(Val(:keep_positive), layer) # re-modify mutated pullback
362+
Rₖ .+= aₖ⁺ .* only(pullback⁺(y))
363+
reset!()
343364
return nothing
344365
end
345366

@@ -355,8 +376,9 @@ for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
     @eval function lrp!(Rₖ, rule::$R, layer::Dense, aₖ, Rₖ₊₁)
         reset! = get_layer_resetter(rule, layer)
         modify_layer!(rule, layer)
-        ãₖ₊₁ = modify_denominator(rule, layer(modify_input(rule, aₖ)))
-        @tullio Rₖ[j, b] = aₖ[j, b] * layer.weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b]
+        ãₖ = modify_input(rule, aₖ)
+        ãₖ₊₁ = modify_denominator(rule, preactivation(layer, ãₖ))
+        @tullio Rₖ[j, b] = ãₖ[j, b] * layer.weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b]
        reset!()
        return nothing
    end

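Editor's note on the "re-modify mutated pullback" pattern above (not from the diff): Zygote pullbacks close over the layer's parameter arrays by reference, so a pullback taken while the layer held modified weights must also be evaluated while those weights are in place; hence the interleaved reset!()/modify_layer! calls around the pullback evaluations. A minimal standalone sketch of the failure mode this avoids:

using Zygote

W = [1.0 2.0; 3.0 4.0]
f(x) = W * x                     # stands in for a layer with mutable parameters
y, back = Zygote.pullback(f, [1.0, 1.0])
fill!(W, 0.0)                    # mutate the parameters after taking the pullback
only(back([1.0, 1.0]))           # returns zeros: the pullback sees the mutated W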