
Commit 53b9f95

Add inplace updating lrp! rule calls and reuse gradient code (#38)
* Change `lrp` to `lrp!`, which updates `Rₖ` in place (see the call-pattern sketch below the file summary)
* Faster LRP preallocation
* Reuse gradient method code with `gradient_wrt_input`
* Drop precompilation
1 parent c778bba commit 53b9f95

7 files changed (+76 / -99 lines)

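The core API change, in one place: `lrp` used to allocate and return the relevance array, while `lrp!` now writes into a caller-supplied buffer. A minimal sketch of the new call pattern, using placeholder names for `rule`, `layer`, `aₖ` and `Rₖ₊₁` (see the benchmark and test updates below for real uses):

# before this commit: Rₖ = lrp(rule, layer, aₖ, Rₖ₊₁)
Rₖ = similar(aₖ)                 # preallocate a relevance buffer matching the activations
lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)  # overwrites Rₖ in place and returns nothing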

benchmark/benchmarks.jl

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,7 @@
 using BenchmarkTools
 using Flux
 using ExplainabilityMethods
-import ExplainabilityMethods: _modify_layer, lrp
+import ExplainabilityMethods: _modify_layer, lrp!

 on_CI = haskey(ENV, "GITHUB_ACTIONS")

@@ -44,7 +44,7 @@ struct TestWrapper{T}
 end
 (w::TestWrapper)(x) = w.layer(x)
 _modify_layer(r::AbstractLRPRule, w::TestWrapper) = _modify_layer(r, w.layer)
-lrp(rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp(rule, w.layer, aₖ, Rₖ₊₁)
+lrp!(rule::ZBoxRule, w::TestWrapper, Rₖ, aₖ, Rₖ₊₁) = lrp!(rule, w.layer, Rₖ, aₖ, Rₖ₊₁)

 # generate input for conv layers
 insize = (64, 64, 3, 1)
@@ -69,11 +69,11 @@ rules = Dict(
 SUITE["Layer"] = BenchmarkGroup([k for k in keys(layers)])
 for (layername, (layer, aₖ)) in layers
     SUITE["Layer"][layername] = BenchmarkGroup([k for k in keys(rules)])
-
+    Rₖ = similar(aₖ)
     Rₖ₊₁ = layer(aₖ)
     for (rulename, rule) in rules
-        SUITE["Layer"][layername][rulename] = @benchmarkable lrp(
-            $(rule), $(layer), $(aₖ), $(Rₖ₊₁)
+        SUITE["Layer"][layername][rulename] = @benchmarkable lrp!(
+            $(rule), $(layer), $(Rₖ), $(aₖ), $(Rₖ₊₁)
         )
     end
 end
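Because `Rₖ`, `aₖ` and `Rₖ₊₁` are interpolated into `@benchmarkable` with `$`, the buffer is allocated once outside the timed expression and only the in-place rule call is measured. The suite is then run with the usual BenchmarkTools workflow; a usage sketch, not part of this commit:

using BenchmarkTools
tune!(SUITE)          # choose evaluation and sample counts per benchmark
results = run(SUITE)  # nested BenchmarkGroup of timing results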

src/ExplainabilityMethods.jl

Lines changed: 7 additions & 6 deletions
@@ -1,13 +1,16 @@
 module ExplainabilityMethods

+using Base.Iterators
 using LinearAlgebra
 using Flux
 using Zygote
-using ColorSchemes
-using ImageCore
-using Base.Iterators
 using Tullio

+# Heatmapping:
+using ImageCore
+using ColorSchemes
+
+# Model checks:
 using Markdown
 using PrettyTables

@@ -20,8 +23,6 @@ include("lrp_checks.jl")
 include("lrp_rules.jl")
 include("lrp.jl")
 include("heatmap.jl")
-include("precompile.jl")
-_precompile_()

 export analyze

@@ -34,7 +35,7 @@ export LRP, LRPZero, LRPEpsilon, LRPGamma
 export AbstractLRPRule
 export LRP_CONFIG
 export ZeroRule, EpsilonRule, GammaRule, ZBoxRule
-export lrp, modify_params, modify_denominator
+export modify_params, modify_denominator
 export check_model

 # heatmapping

src/gradient.jl

Lines changed: 7 additions & 3 deletions
@@ -1,3 +1,7 @@
+function gradient_wrt_input(model, input::T, output_neuron)::T where {T}
+    return only(gradient((in) -> model(in)[output_neuron], input))
+end
+
 """
     Gradient(model)

@@ -10,8 +14,8 @@ end
 function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
     output = analyzer.model(input)
     output_neuron = ns(output)
-    attr = gradient((in) -> analyzer.model(in)[output_neuron], input)[1]
-    return Explanation(attr, output, output_neuron, :Gradient, Nothing)
+    grad = gradient_wrt_input(analyzer.model, input, output_neuron)
+    return Explanation(grad, output, output_neuron, :Gradient, Nothing)
 end

 """
@@ -29,6 +33,6 @@ end
 function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
     output = analyzer.model(input)
     output_neuron = ns(output)
-    attr = input .* gradient((in) -> analyzer.model(in)[output_neuron], input)[1]
+    attr = input .* gradient_wrt_input(analyzer.model, input, output_neuron)
     return Explanation(attr, output, output_neuron, :InputTimesGradient, Nothing)
 end
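Both `Gradient` and `InputTimesGradient` now delegate to the shared `gradient_wrt_input` helper. A toy usage sketch; the model, input size and neuron index are made up for illustration, and the helper is internal, so it is imported explicitly:

using Flux
import ExplainabilityMethods: gradient_wrt_input

model = Chain(Dense(4, 3, relu), Dense(3, 2))  # hypothetical toy model
x = rand(Float32, 4)
g = gradient_wrt_input(model, x, 1)            # ∂ model(x)[1] / ∂ x, same type and size as x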

src/lrp.jl

Lines changed: 10 additions & 9 deletions
@@ -47,23 +47,24 @@ LRPEpsilon(model::Chain; kwargs...) = LRP(model, EpsilonRule(); kwargs...)
 LRPGamma(model::Chain; kwargs...) = LRP(model, GammaRule(); kwargs...)

 # The call to the LRP analyzer.
-function (analyzer::LRP)(input, ns::AbstractNeuronSelector; layerwise_relevances=false)
+function (analyzer::LRP)(
+    input::AbstractArray{T}, ns::AbstractNeuronSelector; layerwise_relevances=false
+) where {T}
     layers = analyzer.model.layers
-    acts = Vector{Any}([input])
-    # Forward pass through layers, keeping track of activations
-    for layer in layers
-        append!(acts, [layer(acts[end])])
-    end
-    rels = deepcopy(acts) # allocate arrays
+    # Compute layerwise activations on forward pass through model:
+    acts = [input, Flux.activations(analyzer.model, input)...]
+
+    # Allocate array for layerwise relevances:
+    rels = similar.(acts)

     # Mask output neuron
     output_neuron = ns(acts[end])
-    rels[end] *= 0
+    rels[end] .= zero(T)
     rels[end][output_neuron] = acts[end][output_neuron]

     # Backward pass through layers, applying LRP rules
     for (i, rule) in Iterators.reverse(enumerate(analyzer.rules))
-        rels[i] .= lrp(rule, layers[i], acts[i], rels[i + 1])
+        lrp!(rule, layers[i], rels[i], acts[i], rels[i + 1]) # inplace update rels[i]
     end

     return Explanation(
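The manual forward loop is replaced by `Flux.activations`, which returns every intermediate output of a `Chain` in one pass, and `similar.(acts)` preallocates matching relevance buffers instead of `deepcopy`ing the activations. A small sketch with a made-up model:

using Flux

model = Chain(Dense(4, 3, relu), Dense(3, 2))  # hypothetical toy model
x = rand(Float32, 4)
acts = [x, Flux.activations(model, x)...]      # input plus one activation per layer
rels = similar.(acts)                          # uninitialized buffers, filled in place by lrp! on the backward pass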

src/lrp_rules.jl

Lines changed: 25 additions & 20 deletions
@@ -4,13 +4,13 @@
 # can be implemented by dispatching on the functions `modify_params` & `modify_denominator`,
 # which make use of the generalized LRP implementation shown in [1].
 #
-# If the relevance propagation falls outside of this scheme, custom functions
+# If the relevance propagation falls outside of this scheme, custom low-level functions
 # ```julia
-# (::MyLRPRule)(layer, aₖ, Rₖ₊₁) = ...
-# (::MyLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
-# (::AbstractLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
+# lrp!(::MyLRPRule, layer, Rₖ, aₖ, Rₖ₊₁) = ...
+# lrp!(::MyLRPRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁) = ...
+# lrp!(::AbstractLRPRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁) = ...
 # ```
-# that return `Rₖ` can be implemented.
+# that inplace-update `Rₖ` can be implemented.
 # This is used for the ZBoxRule and for faster computations on common layers.
 #
 # References:
@@ -22,12 +22,13 @@ abstract type AbstractLRPRule end
 # This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
 # It can be extended for new rules via `modify_denominator` and `modify_params`.
 # Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
-function lrp(rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
-    return lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+function lrp!(rule::R, layer::L, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
+    lrp_autodiff!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+    return nothing
 end

-function lrp_autodiff(
-    rule::R, layer::L, aₖ::T1, Rₖ₊₁::T2
+function lrp_autodiff!(
+    rule::R, layer::L, Rₖ::T1, aₖ::T1, Rₖ₊₁::T2
 ) where {R<:AbstractLRPRule,L,T1,T2}
     layerᵨ = _modify_layer(rule, layer)
     c::T1 = only(
@@ -37,23 +38,26 @@ function lrp_autodiff(
             z ⋅ s
         end,
     )
-    return aₖ .* c # Rₖ
+    Rₖ .= aₖ .* c
+    return nothing
 end

 # For linear layer types such as Dense layers, using autodiff is overkill.
-function lrp(rule::R, layer::Dense, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
-    return lrp_dense(rule, layer, aₖ, Rₖ₊₁)
+function lrp!(rule::R, layer::Dense, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
+    lrp_dense!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+    return nothing
 end

-function lrp_dense(rule::R, l, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
+function lrp_dense!(rule::R, l, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
     ρW, ρb = modify_params(rule, get_params(l)...)
     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
-    return @tullio Rₖ[j] := aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
+    @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
+    return nothing
 end

 # Other special cases that are dispatched on layer type:
-lrp(::AbstractLRPRule, ::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
-lrp(::AbstractLRPRule, ::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
+lrp!(::AbstractLRPRule, ::DropoutLayer, Rₖ, aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
+lrp!(::AbstractLRPRule, ::ReshapingLayer, Rₖ, aₖ, Rₖ₊₁) = (Rₖ .= reshape(Rₖ₊₁, size(aₖ)))

 # To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
 # If this isn't done, the following fallbacks are used by default:
@@ -125,10 +129,10 @@ Commonly used on the first layer for pixel input.
 struct ZBoxRule <: AbstractLRPRule end

 # The ZBoxRule requires its own implementation of relevance propagation.
-lrp(::ZBoxRule, layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
-lrp(::ZBoxRule, layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+lrp!(::ZBoxRule, layer::Dense, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(layer, Rₖ, aₖ, Rₖ₊₁)
+lrp!(::ZBoxRule, layer::Conv, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(layer, Rₖ, aₖ, Rₖ₊₁)

-function lrp_zbox(layer::L, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
+function lrp_zbox!(layer::L, Rₖ::T1, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
     W, b = get_params(layer)
     l, h = fill.(extrema(aₖ), (size(aₖ),))

@@ -144,5 +148,6 @@ function lrp_zbox(layer::L, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
         s = Zygote.@ignore safedivide(Rₖ₊₁, z; eps=1e-9)
         z ⋅ s
     end
-    return aₖ .* c + l .* cₗ + h .* cₕ # Rₖ from backward pass
+    Rₖ .= aₖ .* c + l .* cₗ + h .* cₕ
+    return nothing
 end
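As the comments in this file describe, most new rules only need `modify_params` and `modify_denominator`; only rules like `ZBoxRule` define their own `lrp!` method. A hypothetical example, where the rule name and both modifications are invented for illustration:

using ExplainabilityMethods
import ExplainabilityMethods: modify_params, modify_denominator

struct MyDampingRule <: AbstractLRPRule end
modify_params(::MyDampingRule, W, b) = (W, zero(b))   # e.g. drop the bias
modify_denominator(::MyDampingRule, d) = d .+ 1.0f-1  # e.g. add a small stabilizer

# If a rule falls outside this scheme, it can instead provide an in-place low-level method:
# lrp!(::MyDampingRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁) = ...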

src/precompile.jl

Lines changed: 0 additions & 44 deletions
This file was deleted.

test/test_rules.jl

Lines changed: 22 additions & 12 deletions
@@ -1,6 +1,6 @@
 using ExplainabilityMethods
 using ExplainabilityMethods: modify_params
-import ExplainabilityMethods: _modify_layer, lrp
+import ExplainabilityMethods: _modify_layer, lrp!
 using Flux
 using LinearAlgebra
 using ReferenceTests
@@ -25,7 +25,9 @@ const RULES = Dict(
     Rₖ = [17 / 90, 316 / 675] # expected output

     layer = Dense(W, b, relu)
-    @test lrp(rule, layer, aₖ, Rₖ₊₁) ≈ Rₖ
+    R̂ₖ = similar(aₖ) # will be inplace updated
+    @inferred lrp!(rule, layer, R̂ₖ, aₖ, Rₖ₊₁)
+    @test R̂ₖ ≈ Rₖ

     ## Pooling layer
     Rₖ₊₁ = Float32.([1 2; 3 4]//30)
@@ -38,7 +40,9 @@ const RULES = Dict(
     Rₖ = reshape(repeat(Rₖ, 1, 3), 3, 3, 3, 1)

     layer = MaxPool((2, 2); stride=(1, 1))
-    @test lrp(rule, layer, aₖ, Rₖ₊₁) ≈ Rₖ
+    R̂ₖ = similar(aₖ) # will be inplace updated
+    @inferred lrp!(rule, layer, R̂ₖ, aₖ, Rₖ₊₁)
+    @test R̂ₖ ≈ Rₖ
 end

 # Fixed pseudo-random numbers
@@ -69,7 +73,8 @@ layers = Dict(
 for (layername, layer) in layers
     @testset "$layername" begin
         Rₖ₊₁ = layer(aₖ)
-        Rₖ = @inferred lrp(rule, layer, aₖ, Rₖ₊₁)
+        Rₖ = similar(aₖ)
+        @inferred lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)

         @test typeof(Rₖ) == typeof(aₖ)
         @test size(Rₖ) == size(aₖ)
@@ -110,14 +115,17 @@ equalpairs = Dict( # these pairs of layers are all equal
        l1, l2 = layers
        Rₖ₊₁ = l1(aₖ)
        @test Rₖ₊₁ == l2(aₖ)
-       Rₖ = @inferred lrp(rule, l1, aₖ, Rₖ₊₁)
-       @test Rₖ == lrp(rule, l2, aₖ, Rₖ₊₁)
+       Rₖ1 = similar(aₖ)
+       Rₖ2 = similar(aₖ)
+       @inferred lrp!(rule, l1, Rₖ1, aₖ, Rₖ₊₁)
+       @inferred lrp!(rule, l2, Rₖ2, aₖ, Rₖ₊₁)
+       @test Rₖ1 == Rₖ2

-       @test typeof(Rₖ) == typeof(aₖ)
-       @test size(Rₖ) == size(aₖ)
+       @test typeof(Rₖ1) == typeof(aₖ)
+       @test size(Rₖ1) == size(aₖ)

        @test_reference "references/rules/$rulename/$layername.jld2" Dict(
-           "R" => Rₖ
+           "R" => Rₖ1
        ) by = (r, a) -> isapprox(r["R"], a["R"]; rtol=0.02)
    end
 end
@@ -143,7 +151,8 @@ layers = Dict(
 for (layername, layer) in layers
     @testset "$layername" begin
         Rₖ₊₁ = layer(aₖ)
-        Rₖ = @inferred lrp(rule, layer, aₖ, Rₖ₊₁)
+        Rₖ = similar(aₖ)
+        @inferred lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)

         @test typeof(Rₖ) == typeof(aₖ)
         @test size(Rₖ) == size(aₖ)
@@ -164,7 +173,7 @@ struct TestWrapper{T}
 end
 (w::TestWrapper)(x) = w.layer(x)
 _modify_layer(r::AbstractLRPRule, w::TestWrapper) = _modify_layer(r, w.layer)
-lrp(rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp(rule, w.layer, aₖ, Rₖ₊₁)
+lrp!(rule::ZBoxRule, w::TestWrapper, Rₖ, aₖ, Rₖ₊₁) = lrp!(rule, w.layer, Rₖ, aₖ, Rₖ₊₁)

 layers = Dict(
     "Conv" => (Conv((3, 3), 2 => 4; init=pseudorandn), aₖ),
@@ -179,7 +188,8 @@ layers = Dict(
     @testset "$layername" begin
         wrapped_layer = TestWrapper(layer)
         Rₖ₊₁ = wrapped_layer(aₖ)
-        Rₖ = @inferred lrp(rule, wrapped_layer, aₖ, Rₖ₊₁)
+        Rₖ = similar(aₖ)
+        @inferred lrp!(rule, wrapped_layer, Rₖ, aₖ, Rₖ₊₁)

         @test typeof(Rₖ) == typeof(aₖ)
         @test size(Rₖ) == size(aₖ)
