@@ -11,8 +11,10 @@ The analyzer can either be created by passing an array of LRP-rules
or by passing a composite, see [`Composite`](@ref) for an example.

# Keyword arguments
- - `skip_checks::Bool`: Skip checks whether model is compatible with LRP and contains output softmax. Default is `false`.
- - `verbose::Bool`: Select whether the model checks should print a summary on failure. Default is `true`.
+ - `normalize_output_relevance`: Selects whether the output relevance should be set to 1 before applying the LRP backward pass.
+   Defaults to `true` to match the literature. If `false`, the values of the output activations are used.
+ - `skip_checks::Bool`: Skip checks whether the model is compatible with LRP and contains an output softmax. Defaults to `false`.
+ - `verbose::Bool`: Select whether the model checks should print a summary on failure. Defaults to `true`.

# References
[1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
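Below is a minimal usage sketch of the new keyword. The model and the rule list are illustrative assumptions; only `LRP` and `normalize_output_relevance` come from this diff, and `ZeroRule`/`EpsilonRule` are assumed to be rule types exported by the surrounding package:

```julia
using Flux

# Illustrative two-layer model; one LRP rule per layer.
model = Chain(Dense(784, 100, relu), Dense(100, 10))
rules = [ZeroRule(), EpsilonRule()]  # assumed rule types, not part of this diff

analyzer = LRP(model, rules)  # default: output relevance normalized to 1
analyzer_raw = LRP(model, rules; normalize_output_relevance=false)  # raw output activations
```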
@@ -22,10 +24,16 @@ struct LRP{C<:Chain,R<:ChainTuple,L<:ChainTuple} <: AbstractXAIMethod
    model::C
    rules::R
    modified_layers::L
+     normalize_output_relevance::Bool

    # Construct LRP analyzer by assigning a rule to each layer
    function LRP(
-         model::Chain, rules::ChainTuple; skip_checks=false, flatten=true, verbose=true
+         model::Chain,
+         rules::ChainTuple;
+         normalize_output_relevance::Bool=true,
+         skip_checks=false,
+         flatten=true,
+         verbose=true,
    )
        if flatten
            model = chainflatten(model)
@@ -37,7 +45,7 @@ struct LRP{C<:Chain,R<:ChainTuple,L<:ChainTuple} <: AbstractXAIMethod
        end
        modified_layers = get_modified_layers(rules, model)
        return new{typeof(model),typeof(rules),typeof(modified_layers)}(
-             model, rules, modified_layers
+             model, rules, modified_layers, normalize_output_relevance
        )
    end
end
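As an aside, the constructor's `flatten=true` path calls `chainflatten`, which presumably splices nested `Chain`s into one flat `Chain` so that the rule tuple lines up one-to-one with the layers. A rough standalone sketch of that idea, under that assumption (the package's actual `chainflatten` may differ):

```julia
using Flux

# Rough sketch: recursively splice nested Chains into a single flat Chain.
splice_layer(l) = [l]
splice_layer(c::Chain) = mapreduce(splice_layer, vcat, c.layers)
flatten_chain(c::Chain) = Chain(splice_layer(c)...)

nested = Chain(Chain(Dense(4, 3), relu), Dense(3, 2))
flat = flatten_chain(nested)  # Chain(Dense(4, 3), relu, Dense(3, 2))
```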
@@ -59,20 +67,25 @@ function call_analyzer(
    input::AbstractArray, lrp::LRP, ns::AbstractOutputSelector; layerwise_relevances=false
)
    as = get_activations(lrp.model, input) # compute activations aᵏ for all layers k
-     Rs = similar.(as) # allocate relevances Rᵏ for all layers k
-     mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N
-
+     Rs = similar.(as)
+     mask_output_neuron!(Rs[end], as[end], ns, lrp.normalize_output_relevance) # compute relevance Rᴺ of output layer N
    lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers)
    extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
    return Explanation(first(Rs), input, last(as), ns(last(as)), :LRP, :attribution, extras)
end

get_activations(model, input) = (input, Flux.activations(model, input)...)

- function mask_output_neuron!(R_out, a_out, ns::AbstractOutputSelector)
+ function mask_output_neuron!(
+     R_out, a_out, ns::AbstractOutputSelector, normalize_output_relevance::Bool
+ )
    fill!(R_out, 0)
    idx = ns(a_out)
-     R_out[idx] .= 1
+     if normalize_output_relevance
+         R_out[idx] .= 1
+     else
+         R_out[idx] .= a_out[idx]
+     end
    return R_out
end
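To make the behavioral difference in `mask_output_neuron!` concrete, here is a self-contained sketch in which `argmax` stands in for the output selector `ns(a_out)` (the `AbstractOutputSelector` types are defined elsewhere in the package):

```julia
# Sketch of the two masking modes; `argmax` replaces the real selector.
function mask_sketch(a_out::AbstractVector; normalize::Bool=true)
    R_out = similar(a_out)
    fill!(R_out, 0)
    idx = argmax(a_out)  # stand-in for ns(a_out)
    R_out[idx] = normalize ? 1 : a_out[idx]
    return R_out
end

a = [0.1, 2.5, 0.4]
mask_sketch(a)                   # [0.0, 1.0, 0.0] (output relevance set to 1)
mask_sketch(a; normalize=false)  # [0.0, 2.5, 0.0] (raw output activation)
```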