
Commit 16119c4

Refactor lrp! (#70)
* Make `Rₖ` the first argument of `lrp!`, matching the bang convention
* Change rule keyword arguments to default arguments
* Type stability fixes for `GammaRule`
* Add explanations for non-Julia programmers to docs
1 parent dc4d0fc commit 16119c4
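
The core of the refactor is visible in every hunk below: the mutated relevance array `Rₖ` moves to the front of the `lrp!` argument list, following Julia's bang convention that the argument being modified comes first (as in `copyto!(dest, src)` or `mul!(C, A, B)`). Schematically:

    # before (Rₖ buried in third position):
    lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)

    # after (mutated buffer first, like other mutating functions in Base):
    lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)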

5 files changed (+45, -40 lines)

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ struct TestWrapper{T}
 end
 (w::TestWrapper)(x) = w.layer(x)
 modify_layer(r::AbstractLRPRule, w::TestWrapper) = modify_layer(r, w.layer)
-lrp!(rule::ZBoxRule, w::TestWrapper, Rₖ, aₖ, Rₖ₊₁) = lrp!(rule, w.layer, Rₖ, aₖ, Rₖ₊₁)
+lrp!(Rₖ, rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp!(Rₖ, rule, w.layer, aₖ, Rₖ₊₁)

 # generate input for conv layers
 insize = (64, 64, 3, 1)
@@ -81,7 +81,7 @@ for (layername, (layer, aₖ)) in layers
     Rₖ₊₁ = layer(aₖ)
     for (rulename, rule) in rules
         SUITE["Layer"][layername][rulename] = @benchmarkable lrp!(
-            $(rule), $(layer), $(Rₖ), $(aₖ), $(Rₖ₊₁)
+            $(Rₖ), $(rule), $(layer), $(aₖ), $(Rₖ₊₁)
         )
     end
 end
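
For readers less familiar with BenchmarkTools.jl: the `$(...)` splices in the hunk above interpolate the pre-built values into the benchmark expression, so timings measure `lrp!` itself rather than global-variable lookup. A standalone sketch of the same pattern with the new argument order (toy layer and sizes, not the ones used in `SUITE`; assumes `lrp!` and `ZeroRule` are in scope, as they are inside the package's benchmark suite):

    using BenchmarkTools
    using Flux: Dense, relu

    layer = Dense(100, 50, relu)
    aₖ    = rand(Float32, 100, 1)
    Rₖ₊₁  = layer(aₖ)
    Rₖ    = similar(aₖ)

    # Mutated buffer first, matching the refactored signature:
    b = @benchmarkable lrp!($(Rₖ), $(ZeroRule()), $(layer), $(aₖ), $(Rₖ₊₁))
    run(b)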

docs/literate/advanced_lrp.jl

Lines changed: 14 additions & 10 deletions
@@ -191,16 +191,20 @@ analyzer = LRPZero(model)
 # ## How it works internally
 # Internally, ExplainableAI dispatches to low level functions
 # ```julia
-# function lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)
 #    Rₖ .= ...
 # end
 # ```
-# These functions use the arguments `rule` and `layer` to dispatch
-# `modify_params` and `modify_denominator` on the rule and layer type.
-# They in-place modify a pre-allocated array of the input relevance `Rₖ`
-# based on the input activation `aₖ` and output relevance `Rₖ₊₁`.
+# These functions in-place modify a pre-allocated array of the input relevance `Rₖ`
+# (the `!` is a [naming convention](https://docs.julialang.org/en/v1/manual/style-guide/#bang-convention)
+# in Julia to denote functions that modify their arguments).
+#
+# The correct rule is applied via [multiple dispatch](https://www.youtube.com/watch?v=kc9HwsxE1OY)
+# on the types of the arguments `rule` and `layer`.
+# The relevance `Rₖ` is then computed based on the input activation `aₖ` and the output relevance `Rₖ₊₁`.
+# Multiple dispatch is also used to dispatch `modify_params` and `modify_denominator` on the rule and layer type.
 #
-# Calling `analyze` then applies a forward-pass of the model, keeping track of
+# Calling `analyze` on a LRP-model applies a forward-pass of the model, keeping track of
 # the activations `aₖ` for each layer `k`.
 # The relevance `Rₖ₊₁` is then set to the output neuron activation and the rules are applied
 # in a backward-pass over the model layers and previous activations.
@@ -241,7 +245,7 @@ analyzer = LRPZero(model)
 # The default LRP fallback for unknown layers uses AD via [Zygote](https://github.com/FluxML/Zygote.jl).
 # For `lrp!`, we end up with something that looks very similar to the previous four step computation:
 # ```julia
-# function lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)
 #     layerᵨ = modify_layer(rule, layer)
 #     c = gradient(aₖ) do a
 #         z = layerᵨ(a)
@@ -263,7 +267,7 @@ analyzer = LRPZero(model)
 # Reshaping layers don't affect attributions. We can therefore avoid the computational
 # overhead of AD by writing a specialized implementation that simply reshapes back:
 # ```julia
-# function lrp!(::AbstractLRPRule, ::ReshapingLayer, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(Rₖ, ::AbstractLRPRule, ::ReshapingLayer, aₖ, Rₖ₊₁)
 #     Rₖ .= reshape(Rₖ₊₁, size(aₖ))
 # end
 # ```
@@ -272,7 +276,7 @@ analyzer = LRPZero(model)
 #
 # We can even implement the generic rule as a specialized implementation for `Dense` layers:
 # ```julia
-# function lrp!(rule::AbstractLRPRule, layer::Dense, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(Rₖ, rule::AbstractLRPRule, layer::Dense, aₖ, Rₖ₊₁)
 #     ρW, ρb = modify_params(rule, get_params(layer)...)
 #     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
 #     @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k] # Tullio ≈ fast einsum
@@ -283,7 +287,7 @@ analyzer = LRPZero(model)
 # you can also implement your own `lrp!` function and dispatch
 # on individual rule types `MyRule` and layer types `MyLayer`:
 # ```julia
-# function lrp!(rule::MyRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(Rₖ, rule::MyRule, layer::MyLayer, aₖ, Rₖ₊₁)
 #     Rₖ .= ...
 # end
 # ```
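
The documentation changes above describe two extension points with the new mutated-first signature: either define `modify_params`/`modify_denominator` for a new rule and let the generic `lrp!` methods pick it up by dispatch, or write a dedicated `lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)` method. A hedged sketch of the first route; `MyClampedGammaRule` is an invented name for illustration, not part of the package:

    using Flux: relu
    # assumes AbstractLRPRule and modify_params from ExplainableAI are in scope

    # Hypothetical rule: GammaRule-style boost of positive weights, capped at γ = 1.
    struct MyClampedGammaRule{T} <: AbstractLRPRule
        γ::T
        MyClampedGammaRule(γ=0.25f0) = new{Float32}(min(γ, 1))
    end

    function modify_params(r::MyClampedGammaRule, W, b)
        T = eltype(W)
        ρW = W + convert(T, r.γ) * relu.(W)
        ρb = b + convert(T, r.γ) * relu.(b)
        return ρW, ρb
    end
    # No lrp! method needed: the AbstractLRPRule fallback and the Dense
    # specialization in src/lrp_rules.jl dispatch on this rule automatically.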

src/lrp.jl

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ function (analyzer::LRP)(

     # Backward pass through layers, applying LRP rules
     for (i, rule) in Iterators.reverse(enumerate(analyzer.rules))
-        lrp!(rule, layers[i], rels[i], acts[i], rels[i + 1]) # inplace update rels[i]
+        lrp!(rels[i], rule, layers[i], acts[i], rels[i + 1]) # inplace update rels[i]
     end

     return Explanation(
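
This loop is the only call site inside the analyzer, and the reorder keeps it aligned with how the buffers are set up: one pre-allocated relevance array per recorded activation, each overwritten in place. A rough sketch of that bookkeeping, assuming a Flux `Chain`-like model with a `layers` field; the setup code here paraphrases the surrounding function rather than quoting it:

    # Forward pass: record the activation entering (and leaving) each layer.
    acts = [input]
    for layer in model.layers
        push!(acts, layer(last(acts)))
    end

    # One relevance buffer per activation; seed the last one at the output.
    rels = similar.(acts)
    rels[end] .= last(acts)            # e.g. masked to the explained output neuron

    # Backward pass: walk rules and layers in reverse, writing rels[i] in place.
    for (i, rule) in Iterators.reverse(enumerate(rules))
        lrp!(rels[i], rule, model.layers[i], acts[i], rels[i + 1])
    end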

src/lrp_rules.jl

Lines changed: 18 additions & 17 deletions
@@ -2,13 +2,13 @@
 abstract type AbstractLRPRule end

 # Generic LRP rule. Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
-function lrp!(rule::R, layer::L, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
-    lrp_autodiff!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
+    lrp_autodiff!(Rₖ, rule, layer, aₖ, Rₖ₊₁)
     return nothing
 end

 function lrp_autodiff!(
-    rule::R, layer::L, Rₖ::T1, aₖ::T1, Rₖ₊₁::T2
+    Rₖ::T1, rule::R, layer::L, aₖ::T1, Rₖ₊₁::T2
 ) where {R<:AbstractLRPRule,L,T1,T2}
     layerᵨ = modify_layer(rule, layer)
     c::T1 = only(
@@ -23,21 +23,21 @@ function lrp_autodiff!(
 end

 # For linear layer types such as Dense layers, using autodiff is overkill.
-function lrp!(rule::R, layer::Dense, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
-    lrp_dense!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+function lrp!(Rₖ, rule::R, layer::Dense, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
+    lrp_dense!(Rₖ, rule, layer, aₖ, Rₖ₊₁)
     return nothing
 end

-function lrp_dense!(rule::R, l, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
+function lrp_dense!(Rₖ, rule::R, l, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
     ρW, ρb = modify_params(rule, get_params(l)...)
     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ .+ ρb)
     @tullio Rₖ[j, b] = aₖ[j, b] * ρW[k, j] / ãₖ₊₁[k, b] * Rₖ₊₁[k, b]
     return nothing
 end

 # Other special cases that are dispatched on layer type:
-lrp!(::AbstractLRPRule, ::DropoutLayer, Rₖ, aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
-lrp!(::AbstractLRPRule, ::ReshapingLayer, Rₖ, aₖ, Rₖ₊₁) = (Rₖ .= reshape(Rₖ₊₁, size(aₖ)))
+lrp!(Rₖ, ::AbstractLRPRule, ::DropoutLayer, aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
+lrp!(Rₖ, ::AbstractLRPRule, ::ReshapingLayer, aₖ, Rₖ₊₁) = (Rₖ .= reshape(Rₖ₊₁, size(aₖ)))

 # To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
 # If this isn't done, the following fallbacks are used by default:
@@ -75,7 +75,7 @@ Constructor for LRP-0 rule. Commonly used on upper layers.
 struct ZeroRule <: AbstractLRPRule end

 """
-    GammaRule(; γ=0.25)
+    GammaRule([γ=0.25])

 Constructor for LRP-``γ`` rule. Commonly used on lower layers.

@@ -84,16 +84,17 @@ Arguments:
 """
 struct GammaRule{T} <: AbstractLRPRule
     γ::T
-    GammaRule(; γ=0.25) = new{Float32}(γ)
+    GammaRule(γ=0.25f0) = new{Float32}(γ)
 end
 function modify_params(r::GammaRule, W, b)
-    ρW = W + r.γ * relu.(W)
-    ρb = b + r.γ * relu.(b)
+    T = eltype(W)
+    ρW = W + convert(T, r.γ) * relu.(W)
+    ρb = b + convert(T, r.γ) * relu.(b)
     return ρW, ρb
 end

 """
-    EpsilonRule(; ϵ=1f-6)
+    EpsilonRule([ϵ=1.0f-6])

 Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.

@@ -102,7 +103,7 @@ Arguments:
 """
 struct EpsilonRule{T} <: AbstractLRPRule
     ϵ::T
-    EpsilonRule(; ϵ=1.0f-6) = new{Float32}(ϵ)
+    EpsilonRule(ϵ=1.0f-6) = new{Float32}(ϵ)
 end
 modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)

@@ -122,8 +123,8 @@ struct ZBoxRule{T} <: AbstractLRPRule
 end

 # The ZBoxRule requires its own implementation of relevance propagation.
-lrp!(r::ZBoxRule, layer::Dense, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(r, layer, Rₖ, aₖ, Rₖ₊₁)
-lrp!(r::ZBoxRule, layer::Conv, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(r, layer, Rₖ, aₖ, Rₖ₊₁)
+lrp!(Rₖ, r::ZBoxRule, layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox!(Rₖ, r, layer, aₖ, Rₖ₊₁)
+lrp!(Rₖ, r::ZBoxRule, layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox!(Rₖ, r, layer, aₖ, Rₖ₊₁)

 _zbox_bound(T, c::Real, in_size) = fill(convert(T, c), in_size)
 function _zbox_bound(T, A::AbstractArray, in_size)
@@ -135,7 +136,7 @@ function _zbox_bound(T, A::AbstractArray, in_size)
     return convert.(T, A)
 end

-function lrp_zbox!(r::ZBoxRule, layer::L, Rₖ::T1, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
+function lrp_zbox!(Rₖ::T1, r::ZBoxRule, layer::L, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
     T = eltype(aₖ)
     in_size = size(aₖ)
     l = _zbox_bound(T, r.low, in_size)
test/test_rules.jl

Lines changed: 10 additions & 10 deletions
@@ -27,7 +27,7 @@ const RULES = Dict(

     layer = Dense(W, b, relu)
     R̂ₖ = similar(aₖ) # will be inplace updated
-    @inferred lrp!(rule, layer, R̂ₖ, aₖ, Rₖ₊₁)
+    @inferred lrp!(R̂ₖ, rule, layer, aₖ, Rₖ₊₁)
     @test R̂ₖ ≈ Rₖ

     ## Pooling layer
@@ -42,7 +42,7 @@ const RULES = Dict(

     layer = MaxPool((2, 2); stride=(1, 1))
     R̂ₖ = similar(aₖ) # will be inplace updated
-    @inferred lrp!(rule, layer, R̂ₖ, aₖ, Rₖ₊₁)
+    @inferred lrp!(R̂ₖ, rule, layer, aₖ, Rₖ₊₁)
     @test R̂ₖ ≈ Rₖ
 end

@@ -53,7 +53,7 @@ pseudorandn(dims...) = randn(MersenneTwister(123), T, dims...)
 ## Test individual rules
 @testset "modify_params" begin
     W, b = [1.0 -1.0; 2.0 0.0], [-1.0, 1.0]
-    ρW, ρb = @inferred modify_params(GammaRule(; γ=0.42), W, b)
+    ρW, ρb = @inferred modify_params(GammaRule(0.42), W, b)
     @test ρW ≈ [1.42 -1.0; 2.84 0.0]
     @test ρb ≈ [-1.0, 1.42]
 end
@@ -67,7 +67,7 @@ aₖ_dense = pseudorandn(ins_dense, batchsize)

 layers = Dict(
     "Dense_relu" => Dense(ins_dense, outs_dense, relu; init=pseudorandn),
-    "Dense_identity" => Dense(Matrix(I, outs_dense, ins_dense), false, identity),
+    "Dense_identity" => Dense(Matrix{Float32}(I, outs_dense, ins_dense), false, identity),
 )
 @testset "Dense" begin
     for (rulename, rule) in RULES
@@ -76,7 +76,7 @@ layers = Dict(
         @testset "$layername" begin
             Rₖ₊₁ = layer(aₖ_dense)
             Rₖ = similar(aₖ_dense)
-            @inferred lrp!(rule, layer, Rₖ, aₖ_dense, Rₖ₊₁)
+            @inferred lrp!(Rₖ, rule, layer, aₖ_dense, Rₖ₊₁)

             @test typeof(Rₖ) == typeof(aₖ_dense)
             @test size(Rₖ) == size(aₖ_dense)
@@ -118,8 +118,8 @@ equalpairs = Dict( # these pairs of layers are all equal
     @test Rₖ₊₁ == l2(aₖ)
     Rₖ1 = similar(aₖ)
     Rₖ2 = similar(aₖ)
-    @inferred lrp!(rule, l1, Rₖ1, aₖ, Rₖ₊₁)
-    @inferred lrp!(rule, l2, Rₖ2, aₖ, Rₖ₊₁)
+    @inferred lrp!(Rₖ1, rule, l1, aₖ, Rₖ₊₁)
+    @inferred lrp!(Rₖ2, rule, l2, aₖ, Rₖ₊₁)
     @test Rₖ1 == Rₖ2

     @test typeof(Rₖ1) == typeof(aₖ)
@@ -152,7 +152,7 @@ layers = Dict(
         @testset "$layername" begin
             Rₖ₊₁ = layer(aₖ)
             Rₖ = similar(aₖ)
-            @inferred lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+            @inferred lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)

             @test typeof(Rₖ) == typeof(aₖ)
             @test size(Rₖ) == size(aₖ)
@@ -173,7 +173,7 @@ struct TestWrapper{T}
 end
 (w::TestWrapper)(x) = w.layer(x)
 modify_layer(r::AbstractLRPRule, w::TestWrapper) = modify_layer(r, w.layer)
-lrp!(rule::ZBoxRule, w::TestWrapper, Rₖ, aₖ, Rₖ₊₁) = lrp!(rule, w.layer, Rₖ, aₖ, Rₖ₊₁)
+lrp!(Rₖ, rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp!(Rₖ, rule, w.layer, aₖ, Rₖ₊₁)

 layers = Dict(
     "Conv" => (Conv((3, 3), 2 => 4; init=pseudorandn), aₖ),
@@ -188,7 +188,7 @@ layers = Dict(
         wrapped_layer = TestWrapper(layer)
         Rₖ₊₁ = wrapped_layer(aₖ)
         Rₖ = similar(aₖ)
-        @inferred lrp!(rule, wrapped_layer, Rₖ, aₖ, Rₖ₊₁)
+        @inferred lrp!(Rₖ, rule, wrapped_layer, aₖ, Rₖ₊₁)

         @test typeof(Rₖ) == typeof(aₖ)
         @test size(Rₖ) == size(aₖ)