Commit e15b91e

Add option to skip normalization of output layer relevance (#22)
* Move `normalize_output_relevance` into analyzer struct
* Update CRP
* Add tests

Co-authored-by: Maximilian Ernst <[email protected]>
1 parent ff1f850 commit e15b91e

File tree: 4 files changed (+65, -14 lines)
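In practice, the new option is passed when constructing the `LRP` analyzer. A minimal usage sketch — the model, input shape, and the loading of the analyzer package (not named on this page) are assumptions; `LRP`, `analyze`, and the `normalize_output_relevance` keyword are taken from the diffs below:

```julia
using Flux
# using ...  # plus the package defining `LRP` and `analyze` (not named on this page)

model = Chain(Dense(784 => 100, relu), Dense(100 => 10))  # hypothetical model
input = rand(Float32, 784, 1)

analyzer_norm = LRP(model)                                   # default: output relevance set to 1
analyzer_raw  = LRP(model; normalize_output_relevance=false) # keep raw output activation

expl_norm = analyze(input, analyzer_norm)
expl_raw  = analyze(input, analyzer_raw)
```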

src/crp.jl

Lines changed: 5 additions & 5 deletions
```diff
@@ -35,18 +35,18 @@ end
 function call_analyzer(
     input::AbstractArray{T,N}, crp::CRP, ns::AbstractOutputSelector
 ) where {T,N}
-    rules = crp.lrp.rules
-    layers = crp.lrp.model.layers
-    modified_layers = crp.lrp.modified_layers
+    # Unpack internal LRP analyzer
+    (; model, rules, modified_layers, normalize_output_relevance) = crp.lrp
+    layers = model.layers

     n_layers = length(layers)
     n_features = number_of_features(crp.features)
     batchsize = size(input, N)

     # Forward pass
     as = get_activations(crp.lrp.model, input) # compute activations aᵏ for all layers k
-    Rs = similar.(as) # allocate relevances Rᵏ for all layers k
-    mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N
+    Rs = similar.(as) # allocate relevances Rᵏ for all layers k
+    mask_output_neuron!(Rs[end], as[end], ns, normalize_output_relevance) # compute relevance Rᴺ of output layer N

     # Allocate array for returned relevance, adding features to batch dimension
     R_return = similar(input, size(input)[1:(end - 1)]..., batchsize * n_features)
```
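The `(; ...)` line above is Julia's property destructuring (available since Julia 1.7), which unpacks struct fields by name. A self-contained illustration with a hypothetical struct:

```julia
struct Example               # hypothetical struct for illustration
    model::String
    rules::Vector{Symbol}
    normalize_output_relevance::Bool
end

ex = Example("chain", [:ZeroRule], true)

# Pulls each listed field into a local variable of the same name,
# equivalent to `model = ex.model`, `rules = ex.rules`, etc.
(; model, rules, normalize_output_relevance) = ex
```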

src/lrp.jl

Lines changed: 22 additions & 9 deletions
```diff
@@ -11,8 +11,10 @@ The analyzer can either be created by passing an array of LRP-rules
 or by passing a composite, see [`Composite`](@ref) for an example.

 # Keyword arguments
-- `skip_checks::Bool`: Skip checks whether model is compatible with LRP and contains output softmax. Default is `false`.
-- `verbose::Bool`: Select whether the model checks should print a summary on failure. Default is `true`.
+- `normalize_output_relevance`: Selects whether output relevance should be set to 1 before applying LRP backward pass.
+  Defaults to `true` to match literature. If `false`, values of output activations are used.
+- `skip_checks::Bool`: Skip checks whether model is compatible with LRP and contains output softmax. Defaults to `false`.
+- `verbose::Bool`: Select whether the model checks should print a summary on failure. Defaults to `true`.

 # References
 [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
@@ -22,10 +24,16 @@ struct LRP{C<:Chain,R<:ChainTuple,L<:ChainTuple} <: AbstractXAIMethod
     model::C
     rules::R
     modified_layers::L
+    normalize_output_relevance::Bool

     # Construct LRP analyzer by assigning a rule to each layer
     function LRP(
-        model::Chain, rules::ChainTuple; skip_checks=false, flatten=true, verbose=true
+        model::Chain,
+        rules::ChainTuple;
+        normalize_output_relevance::Bool=true,
+        skip_checks=false,
+        flatten=true,
+        verbose=true,
     )
         if flatten
             model = chainflatten(model)
@@ -37,7 +45,7 @@ struct LRP{C<:Chain,R<:ChainTuple,L<:ChainTuple} <: AbstractXAIMethod
         end
         modified_layers = get_modified_layers(rules, model)
         return new{typeof(model),typeof(rules),typeof(modified_layers)}(
-            model, rules, modified_layers
+            model, rules, modified_layers, normalize_output_relevance
        )
     end
 end
@@ -59,20 +67,25 @@ function call_analyzer(
     input::AbstractArray, lrp::LRP, ns::AbstractOutputSelector; layerwise_relevances=false
 )
     as = get_activations(lrp.model, input) # compute activations aᵏ for all layers k
-    Rs = similar.(as) # allocate relevances Rᵏ for all layers k
-    mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N
-
+    Rs = similar.(as)
+    mask_output_neuron!(Rs[end], as[end], ns, lrp.normalize_output_relevance) # compute relevance Rᴺ of output layer N
     lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers)
     extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
     return Explanation(first(Rs), input, last(as), ns(last(as)), :LRP, :attribution, extras)
 end

 get_activations(model, input) = (input, Flux.activations(model, input)...)

-function mask_output_neuron!(R_out, a_out, ns::AbstractOutputSelector)
+function mask_output_neuron!(
+    R_out, a_out, ns::AbstractOutputSelector, normalize_output_relevance::Bool
+)
     fill!(R_out, 0)
     idx = ns(a_out)
-    R_out[idx] .= 1
+    if normalize_output_relevance
+        R_out[idx] .= 1
+    else
+        R_out[idx] .= a_out[idx]
+    end
     return R_out
 end
```
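The two branches of `mask_output_neuron!` are easiest to see on a toy output vector. A sketch with `argmax` standing in for the output selector `ns` (the activation values are made up):

```julia
a_out = Float32[0.2, 1.7, 0.5]  # hypothetical output activations
idx   = argmax(a_out)           # stand-in for `ns(a_out)`

R_norm = zero(a_out)
R_norm[idx] = 1                 # normalize_output_relevance = true  → Rᴺ = [0, 1, 0]

R_raw = zero(a_out)
R_raw[idx] = a_out[idx]         # normalize_output_relevance = false → Rᴺ = [0, 1.7, 0]
```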

test/test_batches.jl

Lines changed: 23 additions & 0 deletions
```diff
@@ -39,3 +39,26 @@ for (name, method) in ANALYZERS
         @test expl2_bd.val ≈ expl_batch.val[:, 2]
     end
 end
+
+@testset "Normalized output relevance" begin
+    analyzer1 = LRP(model)
+    analyzer2 = LRP(model; normalize_output_relevance=false)
+
+    e1 = analyze(input_batch, analyzer1)
+    e2 = analyze(input_batch, analyzer2)
+    v1_bd1 = e1.val[:, 1]
+    v1_bd2 = e1.val[:, 2]
+    v2_bd1 = e2.val[:, 1]
+    v2_bd2 = e2.val[:, 2]
+
+    @test isapprox(sum(v1_bd1), 1, atol=0.05)
+    @test isapprox(sum(v1_bd2), 1, atol=0.05)
+    @test !isapprox(sum(v2_bd1), 1; atol=0.05)
+    @test !isapprox(sum(v2_bd2), 1; atol=0.05)
+
+    ratio_bd1 = first(v1_bd1) / first(v2_bd1)
+    ratio_bd2 = first(v1_bd2) / first(v2_bd2)
+    @test !isapprox(ratio_bd1, ratio_bd2)
+    @test v1_bd1 ≈ v2_bd1 * ratio_bd1
+    @test v1_bd2 ≈ v2_bd2 * ratio_bd2
+end
```
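These ratio assertions rely on the LRP backward pass being linear in the output relevance: scaling Rᴺ by a constant scales every Rᵏ by the same constant, so per sample the two explanations differ exactly by the ratio of raw to normalized output relevance. A toy check of that linearity, with a random matrix standing in for one propagation step (real LRP rules use activation-dependent coefficients, but remain linear in the relevance):

```julia
using LinearAlgebra

W_eff = randn(Float32, 3, 5)  # hypothetical effective propagation matrix
R_out = Float32[0, 1, 0]      # normalized output relevance
c     = 1.7f0                 # hypothetical raw output activation

# Scaling the output relevance scales the propagated relevance identically:
@assert W_eff' * (c .* R_out) ≈ c .* (W_eff' * R_out)
```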

test/test_cnn.jl

Lines changed: 15 additions & 0 deletions
```diff
@@ -110,3 +110,18 @@ end
     @test lwr1[1] ≈ lwr2[1]
     @test lwr1[end] ≈ lwr2[end]
 end
+
+@testset "Normalized output relevance" begin
+    analyzer1 = LRP(model)
+    analyzer2 = LRP(model; normalize_output_relevance=false)
+
+    e1 = analyze(input, analyzer1)
+    e2 = analyze(input, analyzer2)
+    v1, v2 = e1.val, e2.val
+
+    @test isapprox(sum(v1), 1, atol=0.05)
+    @test !isapprox(sum(v2), 1; atol=0.05)
+
+    ratio = first(v1) / first(v2)
+    @test v1 ≈ v2 * ratio
+end
```
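For reference, the `atol` form of `isapprox` used throughout these tests checks absolute closeness, and the two call styles (`,` vs `;` before `atol`) are equivalent. The example values here are made up:

```julia
isapprox(0.97, 1, atol=0.05)  # true:  a normalized relevance sum lands near 1
isapprox(1.7, 1; atol=0.05)   # false: a raw-activation sum generally does not
```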
