# https://adrhill.github.io/ExplainableAI.jl/stable/generated/advanced_lrp/#How-it-works-internally
abstract type AbstractLRPRule end

- # TODO: support all linear layers that use properties `weight` and `bias`
- const WeightBiasLayers = (Dense, Conv)
-
# Generic LRP rule. Since it uses autodiff, it is used as a fallback for layer types
# without custom implementations.
function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
+     check_compat(rule, layer)
    reset! = get_layer_resetter(rule, layer)
    modify_layer!(rule, layer)
    ãₖ₊₁, pullback = Zygote.pullback(layer, modify_input(rule, aₖ))
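
# For reference, the propagation sketched above corresponds to the generic LRP rule
# (see e.g. Montavon et al., "Layer-Wise Relevance Propagation: An Overview"), which for
# a fully connected layer reads roughly
#
#   Rⱼ = Σₖ aⱼ ρ(Wⱼₖ) / (ϵ + Σᵢ aᵢ ρ(Wᵢₖ)) * Rₖ
#
# where ρ is the parameter modification applied by `modify_param!`/`modify_layer!` and
# the stabilized denominator comes from `modify_denominator`.
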
# To implement new rules, define the following custom functions:
# * `modify_input(rule, input)`
# * `modify_denominator(rule, d)`
+ # * `check_compat(rule, layer)`
# * `modify_param!(rule, param)` or `modify_layer!(rule, layer)`,
#   the latter overriding the former
#
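# As an illustrative sketch (a hypothetical `MyCustomRule`, not defined in this package),
# a new rule would subtype `AbstractLRPRule` and overload the hooks above:
#
#   struct MyCustomRule <: AbstractLRPRule end
#   modify_param!(::MyCustomRule, param) = (param .+= 0.25f0 * relu.(param))
#   check_compat(rule::MyCustomRule, layer) = require_weight_and_bias(rule, layer)
#
# `modify_input` and `modify_denominator` then fall back to the generic definitions below.
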
@@ -36,6 +35,17 @@ Modify denominator ``z`` for numerical stability on the forward pass.
"""
@inline modify_denominator(rule, d) = stabilize_denom(d, 1.0f-9) # general fallback

+ """
39
+ check_compat(rule, layer)
40
+
41
+ Check compatibility of a LRP-Rule with layer type.
42
+
43
+ ## Note
44
+ When implementing a custom `check_compat` function, return `nothing` if checks passed,
45
+ otherwise throw an `ArgumentError`.
46
+ """
47
+ @inline check_compat (rule, layer) = require_weight_and_bias (rule, layer)
48
+
"""
    modify_layer!(rule, layer)
@@ -45,15 +55,12 @@ propagation.
## Note
When implementing a custom `modify_layer!` function, `modify_param!` will not be called.
"""
- modify_layer!(rule, layer) = nothing
- for L in WeightBiasLayers
-     @eval function modify_layer!(rule::R, layer::$L) where {R}
-         if has_weight_and_bias(layer)
-             modify_param!(rule, layer.weight)
-             modify_bias!(rule, layer.bias)
-         end
-         return nothing
+ function modify_layer!(rule::R, layer::L) where {R,L}
+     if has_weight_and_bias(layer)
+         modify_param!(rule, layer.weight)
+         modify_bias!(rule, layer.bias)
    end
+     return nothing
end

"""
    ZeroRule()

Constructor for LRP-0 rule. Commonly used on upper layers.
"""
struct ZeroRule <: AbstractLRPRule end
+ @inline check_compat(::ZeroRule, layer) = nothing
+
+ """
+     EpsilonRule([ϵ=1.0f-6])
+
+ Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.
+
+ Arguments:
+ - `ϵ`: Optional stabilization parameter, defaults to `1f-6`.
+ """
+ struct EpsilonRule{T} <: AbstractLRPRule
+     ϵ::T
+     EpsilonRule(ϵ=1.0f-6) = new{Float32}(ϵ)
+ end
+ modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
+ @inline check_compat(::EpsilonRule, layer) = nothing
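# For example, `EpsilonRule()` uses the default ϵ = 1.0f-6, while `EpsilonRule(1.0f-5)` overrides it.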

"""
    GammaRule([γ=0.25])
@@ -115,20 +138,7 @@ function modify_param!(r::GammaRule, param::AbstractArray{T}) where {T}
    param .+= γ * relu.(param)
    return nothing
end
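# With the default γ = 0.25, a positive weight of 0.8 becomes 0.8 + 0.25 * 0.8 = 1.0,
# while negative weights are left unchanged since relu(-0.5) = 0.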
-
- """
-     EpsilonRule([ϵ=1.0f-6])
-
- Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.
-
- Arguments:
- - `ϵ`: Optional stabilization parameter, defaults to `1f-6`.
- """
- struct EpsilonRule{T} <: AbstractLRPRule
-     ϵ::T
-     EpsilonRule(ϵ=1.0f-6) = new{Float32}(ϵ)
- end
- modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
+ @inline check_compat(rule::GammaRule, layer) = require_weight_and_bias(rule, layer)

"""
    ZBoxRule(low, high)
@@ -146,45 +156,44 @@ struct ZBoxRule{T} <: AbstractLRPRule
end

# The ZBoxRule requires its own implementation of relevance propagation.
- for L in WeightBiasLayers
-     function lrp!(Rₖ, rule::ZBoxRule, layer, aₖ, Rₖ₊₁)
-         T = eltype(aₖ)
-         l = zbox_input_augmentation(T, rule.low, size(aₖ))
-         h = zbox_input_augmentation(T, rule.high, size(aₖ))
-         reset! = get_layer_resetter(rule, layer)
+ function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
+     require_weight_and_bias(rule, layer)
+     reset! = get_layer_resetter(rule, layer)

-         # Compute pullback for W, b
-         aₖ₊₁, pullback = Zygote.pullback(layer, aₖ)
+     l = zbox_input(aₖ, rule.low)
+     h = zbox_input(aₖ, rule.high)

-         # Compute pullback for W⁺, b⁺
-         modify_layer!(Val{:mask_positive}, layer)
-         aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, l)
-         reset!()
+     # Compute pullback for W, b
+     aₖ₊₁, pullback = Zygote.pullback(layer, aₖ)

-         # Compute pullback for W⁻, b⁻
-         modify_layer!(Val{:mask_negative}, layer)
-         aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, h)
-         reset!()
+     # Compute pullback for W⁺, b⁺
+     modify_layer!(Val{:mask_positive}, layer)
+     aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, l)
+     reset!()

-         y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ - aₖ₊₁⁺ - aₖ₊₁⁻)
-         Rₖ .= aₖ .* only(pullback(y)) - l .* only(pullback⁺(y)) - h .* only(pullback⁻(y))
-         return nothing
-     end
+     # Compute pullback for W⁻, b⁻
+     modify_layer!(Val{:mask_negative}, layer)
+     aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, h)
+     reset!()
+
+     y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ - aₖ₊₁⁺ - aₖ₊₁⁻)
+     Rₖ .= aₖ .* only(pullback(y)) - l .* only(pullback⁺(y)) - h .* only(pullback⁻(y))
+     return nothing
end
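# For reference, this corresponds to the zᴮ-rule for input layers (see e.g. Montavon et al.,
# "Methods for Interpreting and Understanding Deep Neural Networks"), roughly
#
#   Rⱼ = Σₖ (aⱼWⱼₖ - lⱼWⱼₖ⁺ - hⱼWⱼₖ⁻) / Σᵢ (aᵢWᵢₖ - lᵢWᵢₖ⁺ - hᵢWᵢₖ⁻) * Rₖ
#
# with W⁺ = max(0, W) and W⁻ = min(0, W), implemented here via three pullbacks.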

- const ZBOX_BOUNDS_MISMATCH = "ZBoxRule bounds should either be scalar or match input size."
- function zbox_input_augmentation(T, A::AbstractArray, in_size)
-     size(A) != in_size && throw(ArgumentError(ZBOX_BOUNDS_MISMATCH))
+ zbox_input(in::AbstractArray{T}, c::Real) where {T} = fill(convert(T, c), size(in))
+ function zbox_input(in::AbstractArray{T}, A::AbstractArray) where {T}
+     @assert size(A) == size(in)
    return convert.(T, A)
end
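# `zbox_input` broadcasts a scalar bound to the input's size and element type,
# or converts an array bound after checking that its size matches the input.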
- zbox_input_augmentation(T, c::Real, in_size) = fill(convert(T, c), in_size)

- # Other special cases that are dispatched on layer type:
- const LRPRules = (ZeroRule, EpsilonRule, GammaRule, ZBoxRule)
- for R in LRPRules
+ # Special cases for rules that don't modify parameters, added for extra performance:
+ for R in (ZeroRule, EpsilonRule)
+     @eval get_layer_resetter(::$R, l) = Returns(nothing)
    @eval lrp!(Rₖ, ::$R, ::DropoutLayer, aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
    @eval lrp!(Rₖ, ::$R, ::ReshapingLayer, aₖ, Rₖ₊₁) = (Rₖ .= reshape(Rₖ₊₁, size(aₖ)))
end
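# (Dropout layers act as the identity at inference time and reshaping layers only
# rearrange activations, so relevance is passed through unchanged resp. reshaped.)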
+

# Fast implementation for Dense layer using Tullio.jl's einsum notation:
for R in (ZeroRule, EpsilonRule, GammaRule)
    @eval function lrp!(Rₖ, rule::$R, layer::Dense, aₖ, Rₖ₊₁)
@@ -196,7 +205,3 @@ for R in (ZeroRule, EpsilonRule, GammaRule)
        return nothing
    end
end
-
- # Rules that don't modify params can optionally be added here for extra performance
- get_layer_resetter(::ZeroRule, l) = Returns(nothing)
- get_layer_resetter(::EpsilonRule, l) = Returns(nothing)
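
# A minimal usage sketch, assuming the package's public `LRP` analyzer and `analyze`
# function (`model` and the input are placeholders):
#
#   using ExplainableAI, Flux
#   model = Chain(Dense(784, 100, relu), Dense(100, 10))
#   analyzer = LRP(model)                           # default rule per layer
#   expl = analyze(rand(Float32, 784, 1), analyzer)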