Skip to content

Commit 78cedfc

Browse files
authored
Improve LRP backward pass (#143)
* Move inner LRP loop into `lrp_backward_pass!`
* Reduce amount of Unicode
* Add Tullio to benchmark dependencies
1 parent 7a57f3a commit 78cedfc

File tree

3 files changed

+27
-24
lines changed

3 files changed

+27
-24
lines changed

benchmark/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ ExplainableAI = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
44
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
55
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
66
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
7+
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
78

89
[compat]
910
BenchmarkTools = "1"

benchmark/benchmarks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using BenchmarkTools
22
using LoopVectorization
3+
using Tullio
34
using Flux
45
using ExplainableAI
56
using ExplainableAI: lrp!, modify_layer

src/lrp/lrp.jl

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,21 @@ LRP(model::Chain; kwargs...) = LRP(model, Composite(ZeroRule()); kwargs...)
4646
# Construct Chain-/ParallelTuple of rules by applying composite
4747
LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwargs...)
4848

49-
get_activations(model, input) = [input, Flux.activations(model, input)...]
49+
get_activations(model, input) = (input, Flux.activations(model, input)...)
5050

51-
function mask_output_neuron!(Rᴺ, aᴺ, ns::AbstractNeuronSelector)
52-
fill!(Rᴺ, 0)
53-
idx = ns(aᴺ)
54-
Rᴺ[idx] .= 1
55-
return Rᴺ
51+
function mask_output_neuron!(R_out, a_out, ns::AbstractNeuronSelector)
52+
fill!(R_out, 0)
53+
idx = ns(a_out)
54+
R_out[idx] .= 1
55+
return R_out
56+
end
57+
58+
function lrp_backward_pass!(Rs, as, rules, layers, modified_layers)
59+
# Apply LRP rules in backward-pass, inplace-updating relevances `Rs[k]` = Rᵏ
60+
for k in length(layers):-1:1
61+
lrp!(Rs[k], rules[k], layers[k], modified_layers[k], as[k], Rs[k + 1])
62+
end
63+
return Rs
5664
end
5765

5866
# Call to the LRP analyzer
@@ -63,11 +71,7 @@ function (lrp::LRP)(
6371
Rs = similar.(as) # allocate relevances Rᵏ for all layers k
6472
mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N
6573

66-
# Apply LRP rules in backward-pass, inplace-updating relevances `Rs[k]` = Rᵏ
67-
for k in length(lrp.model):-1:1
68-
lrp!(Rs[k], lrp.rules[k], lrp.model[k], lrp.modified_layers[k], as[k], Rs[k + 1])
69-
end
70-
74+
lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers)
7175
extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
7276
return Explanation(first(Rs), last(as), ns(last(as)), :LRP, extras)
7377
end
@@ -77,10 +81,7 @@ function lrp!(Rᵏ, rules::ChainTuple, chain::Chain, modified_chain::ChainTuple,
7781
Rs = similar.(as)
7882
last(Rs) .= Rᵏ⁺¹
7983

80-
# Apply LRP rules in backward-pass, inplace-updating relevances `Rs[i]`
81-
for i in length(chain):-1:1
82-
lrp!(Rs[i], rules[i], chain[i], modified_chain[i], as[i], Rs[i + 1])
83-
end
84+
lrp_backward_pass!(Rs, as, rules, chain, modified_chain)
8485
return Rᵏ .= first(Rs)
8586
end
8687

@@ -91,15 +92,15 @@ function lrp!(
9192
# according to the contribution aᵏ⁺¹ᵢ of branch i to the output activation aᵏ⁺¹:
9293
# Rᵏ⁺¹ᵢ = Rᵏ⁺¹ .* aᵏ⁺¹ᵢ ./ aᵏ⁺¹ = c .* aᵏ⁺¹ᵢ
9394

94-
aᵏ⁺¹s = [l(aᵏ) for l in parallel.layers] # aᵏ⁺¹ᵢ for each branch i
95-
c = Rᵏ⁺¹ ./ stabilize_denom(sum(aᵏ⁺¹s))
96-
Rᵏ⁺¹s = [c .* aᵏ⁺¹ᵢ for aᵏ⁺¹ᵢ in aᵏ⁺¹s] # Rᵏ⁺¹ᵢ for each branch i
97-
Rᵏs = [similar(aᵏ) for _ in parallel.layers] # pre-allocate output Rᵏᵢ for each branch i
95+
aᵏ⁺¹_parallel = [layer(aᵏ) for layer in parallel.layers] # aᵏ⁺¹ᵢ for each branch i
96+
c = Rᵏ⁺¹ ./ stabilize_denom(sum(aᵏ⁺¹_parallel))
97+
Rᵏ⁺¹_parallel = [c .* a for a in aᵏ⁺¹_parallel] # Rᵏ⁺¹ᵢ for each branch i
98+
Rᵏ_parallel = [similar(aᵏ) for _ in parallel.layers] # pre-allocate output Rᵏᵢ for each branch
9899

99-
for (Rᵏᵢ, ruleᵢ, layerᵢ, modified_layerᵢ, Rᵏ⁺¹ᵢ) in
100-
zip(Rᵏs, rules, parallel.layers, modified_parallel, Rᵏ⁺¹s)
101-
# In-place update Rᵏᵢ (and therefore Rᵏs)
102-
lrp!(Rᵏᵢ, ruleᵢ, layerᵢ, modified_layerᵢ, aᵏ, Rᵏ⁺¹ᵢ)
100+
for (Rᵏᵢ, rule, layer, modified_layer, Rᵏ⁺¹ᵢ) in
101+
zip(Rᵏ_parallel, rules, parallel.layers, modified_parallel, Rᵏ⁺¹_parallel)
102+
# In-place update Rᵏᵢ and therefore Rᵏ_parallel
103+
lrp!(Rᵏᵢ, rule, layer, modified_layer, aᵏ, Rᵏ⁺¹ᵢ)
103104
end
104-
return Rᵏ .= sum(Rᵏs)
105+
return Rᵏ .= sum(Rᵏ_parallel)
105106
end

0 commit comments

Comments
 (0)