@@ -4,23 +4,27 @@
 # can be implemented by dispatching on the functions `modify_params` & `modify_denominator`,
 # which make use of the generalized LRP implementation shown in [1].
 #
-# If the relevance propagation falls outside of this scheme, a custom function
+# If the relevance propagation falls outside of this scheme, custom functions
 # ```julia
 # (::MyLRPRule)(layer, aₖ, Rₖ₊₁) = ...
+# (::MyLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
+# (::AbstractLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
 # ```
-# can be implemented. This is used for the ZBoxRule.
+# that return `Rₖ` can be implemented.
+# This is used for the ZBoxRule and for faster computations on common layers.
 #
 # References:
 # [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
-# [2] W. Samek et al., Explaining Deep Neural Networks and Beyond:
-#     A Review of Methods and Applications
+# [2] W. Samek et al., Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications
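To make the first extension point concrete: a rule that only rescales the layer parameters needs nothing beyond a struct and a `modify_params` method, after which the generic propagation below applies. A minimal sketch; `MyGammaRule` and the factor `0.25` are illustrative assumptions, not part of this commit:

```julia
using Flux: relu

# Hypothetical γ-style rule: propagation favors positively-weighted contributions.
struct MyGammaRule <: AbstractLRPRule end

# Only the parameters are modified; the generic LRP implementation does the rest.
modify_params(::MyGammaRule, W, b) = (W + 0.25 * relu.(W), b + 0.25 * relu.(b))
```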
 
 abstract type AbstractLRPRule end
 
 # This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
 # It can be extended for new rules via `modify_denominator` and `modify_params`.
 # Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
-function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
+(rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁) = lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+
+function lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
     layerᵨ = _modify_layer(rule, layer)
     function fwpass(a)
         z = layerᵨ(a)
@@ -30,7 +34,16 @@ function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
     return aₖ .* gradient(fwpass, aₖ)[1] # Rₖ
 end
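The middle of `fwpass` is elided by the hunk boundary. Independent of those lines, the Gradient×Input identity this fallback relies on (eq. (7)–(8) of [1]) can be demonstrated in isolation. A toy sketch assuming Zygote; all names and values are illustrative, not from this commit:

```julia
using Zygote  # provides `gradient` and `dropgrad`

W    = [1.0 -2.0 0.5; 0.0 1.5 -1.0]  # toy weights (assumed values)
aₖ   = [1.0, 2.0, 3.0]               # toy input activation
Rₖ₊₁ = [0.5, 0.5]                    # toy incoming relevance

function fwpass(a)
    z = W * a                                  # stand-in for layerᵨ(a)
    s = Zygote.dropgrad(Rₖ₊₁ ./ (z .+ 1e-9))   # treat the ratio as a constant
    return sum(z .* s)
end

# aₖ .* Wᵀ(Rₖ₊₁ ./ z) — exactly the LRP-0 redistribution for this layer
Rₖ = aₖ .* gradient(fwpass, aₖ)[1]
```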
 
-# Special cases are dispatched on layer type:
+# For linear layer types such as Dense layers, using autodiff is overkill.
+(rule::AbstractLRPRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_dense(rule, layer, aₖ, Rₖ₊₁)
+
+function lrp_dense(rule, l, aₖ, Rₖ₊₁)
+    ρW, ρb = modify_params(rule, get_params(l)...)
+    ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
+    return @tullio Rₖ[j] := aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
+end
+
+# Other special cases that are dispatched on layer type:
 (::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
 (::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
 
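Since every rule exposes the same `(rule)(layer, aₖ, Rₖ₊₁) -> Rₖ` call, a whole LRP backward pass reduces to folding over the layers in reverse, with dispatch picking the right rule body per layer. A minimal sketch; `lrp_backward` and its arguments are assumptions for illustration, not part of this commit:

```julia
# Walk the network back to front; as[k] holds the activation entering layer k.
function lrp_backward(layers::Vector, rules::Vector, as::Vector, Rout)
    R = Rout  # relevance at the output layer
    for k in reverse(eachindex(layers))
        R = rules[k](layers[k], as[k], R)  # Rₖ from aₖ and Rₖ₊₁
    end
    return R  # relevance attributed to the input
end
```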
@@ -104,7 +117,10 @@ Commonly used on the first layer for pixel input.
 struct ZBoxRule <: AbstractLRPRule end
 
 # The ZBoxRule requires its own implementation of relevance propagation.
-function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
+(rule::ZBoxRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+(rule::ZBoxRule)(layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+
+function lrp_zbox(layer, aₖ, Rₖ₊₁)
     W, b = get_params(layer)
     l, h = fill.(extrema(aₖ), (size(aₖ),))
 
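The hunk is truncated before the body of `lrp_zbox` finishes. For orientation only: `l` and `h` are constant arrays filled with the minimum and maximum of `aₖ`, and the zᴮ-rule from Montavon et al. [1] that this function presumably goes on to compute is

$$R_j = \sum_k \frac{a_j w_{jk} - l_j w_{jk}^{+} - h_j w_{jk}^{-}}{\sum_j \left(a_j w_{jk} - l_j w_{jk}^{+} - h_j w_{jk}^{-}\right)}\, R_k$$

where $w^{+} = \max(0, w)$ and $w^{-} = \min(0, w)$.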