4
4
# can be implemented by dispatching on the functions `modify_params` & `modify_denominator`,
5
5
# which make use of the generalized LRP implementation shown in [1].
6
6
#
7
- # If the relevance propagation falls outside of this scheme, custom functions
7
+ # If the relevance propagation falls outside of this scheme, custom low-level functions
8
8
# ```julia
9
- # (::MyLRPRule)( layer, aₖ, Rₖ₊₁) = ...
10
- # (::MyLRPRule)( layer::MyLayer, aₖ, Rₖ₊₁) = ...
11
- # (::AbstractLRPRule)( layer::MyLayer, aₖ, Rₖ₊₁) = ...
9
+ # lrp! (::MyLRPRule, layer, Rₖ , aₖ, Rₖ₊₁) = ...
10
+ # lrp! (::MyLRPRule, layer::MyLayer, Rₖ , aₖ, Rₖ₊₁) = ...
11
+ # lrp! (::AbstractLRPRule, layer::MyLayer, Rₖ , aₖ, Rₖ₊₁) = ...
12
12
# ```
13
- # that return `Rₖ` can be implemented.
13
+ # that inplace-update `Rₖ` can be implemented.
14
14
# This is used for the ZBoxRule and for faster computations on common layers.
15
15
#
16
16
# References:
@@ -22,12 +22,13 @@ abstract type AbstractLRPRule end
22
22
# This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
23
23
# It can be extended for new rules via `modify_denominator` and `modify_params`.
24
24
# Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
25
- function lrp (rule:: R , layer:: L , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule ,L}
26
- return lrp_autodiff (rule, layer, aₖ, Rₖ₊₁)
25
+ function lrp! (rule:: R , layer:: L , Rₖ, aₖ, Rₖ₊₁) where {R<: AbstractLRPRule ,L}
26
+ lrp_autodiff! (rule, layer, Rₖ, aₖ, Rₖ₊₁)
27
+ return nothing
27
28
end
28
29
29
- function lrp_autodiff (
30
- rule:: R , layer:: L , aₖ:: T1 , Rₖ₊₁:: T2
30
+ function lrp_autodiff! (
31
+ rule:: R , layer:: L , Rₖ :: T1 , aₖ:: T1 , Rₖ₊₁:: T2
31
32
) where {R<: AbstractLRPRule ,L,T1,T2}
32
33
layerᵨ = _modify_layer (rule, layer)
33
34
c:: T1 = only (
@@ -37,23 +38,26 @@ function lrp_autodiff(
37
38
z ⋅ s
38
39
end ,
39
40
)
40
- return aₖ .* c # Rₖ
41
+ Rₖ .= aₖ .* c
42
+ return nothing
41
43
end
42
44
43
45
# For linear layer types such as Dense layers, using autodiff is overkill.
44
- function lrp (rule:: R , layer:: Dense , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
45
- return lrp_dense (rule, layer, aₖ, Rₖ₊₁)
46
+ function lrp! (rule:: R , layer:: Dense , Rₖ, aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
47
+ lrp_dense! (rule, layer, Rₖ, aₖ, Rₖ₊₁)
48
+ return nothing
46
49
end
47
50
48
- function lrp_dense (rule:: R , l, aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
51
+ function lrp_dense! (rule:: R , l, Rₖ , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
49
52
ρW, ρb = modify_params (rule, get_params (l)... )
50
53
ãₖ₊₁ = modify_denominator (rule, ρW * aₖ + ρb)
51
- return @tullio Rₖ[j] := aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
54
+ @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
55
+ return nothing
52
56
end
53
57
54
58
# Other special cases that are dispatched on layer type:
55
- lrp (:: AbstractLRPRule , :: DropoutLayer , aₖ, Rₖ₊₁) = Rₖ₊₁
56
- lrp (:: AbstractLRPRule , :: ReshapingLayer , aₖ, Rₖ₊₁) = reshape (Rₖ₊₁, size (aₖ))
59
+ lrp! (:: AbstractLRPRule , :: DropoutLayer , Rₖ, aₖ, Rₖ₊₁) = (Rₖ . = Rₖ₊₁)
60
+ lrp! (:: AbstractLRPRule , :: ReshapingLayer , Rₖ, aₖ, Rₖ₊₁) = (Rₖ . = reshape (Rₖ₊₁, size (aₖ) ))
57
61
58
62
# To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
59
63
# If this isn't done, the following fallbacks are used by default:
@@ -125,10 +129,10 @@ Commonly used on the first layer for pixel input.
125
129
struct ZBoxRule <: AbstractLRPRule end
126
130
127
131
# The ZBoxRule requires its own implementation of relevance propagation.
128
- lrp (:: ZBoxRule , layer:: Dense , aₖ, Rₖ₊₁) = lrp_zbox (layer, aₖ, Rₖ₊₁)
129
- lrp (:: ZBoxRule , layer:: Conv , aₖ, Rₖ₊₁) = lrp_zbox (layer, aₖ, Rₖ₊₁)
132
+ lrp! (:: ZBoxRule , layer:: Dense , Rₖ, aₖ, Rₖ₊₁) = lrp_zbox! (layer, Rₖ , aₖ, Rₖ₊₁)
133
+ lrp! (:: ZBoxRule , layer:: Conv , Rₖ, aₖ, Rₖ₊₁) = lrp_zbox! (layer, Rₖ , aₖ, Rₖ₊₁)
130
134
131
- function lrp_zbox (layer:: L , aₖ:: T1 , Rₖ₊₁:: T2 ) where {L,T1,T2}
135
+ function lrp_zbox! (layer:: L , Rₖ :: T1 , aₖ:: T1 , Rₖ₊₁:: T2 ) where {L,T1,T2}
132
136
W, b = get_params (layer)
133
137
l, h = fill .(extrema (aₖ), (size (aₖ),))
134
138
@@ -144,5 +148,6 @@ function lrp_zbox(layer::L, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
144
148
s = Zygote. @ignore safedivide (Rₖ₊₁, z; eps= 1e-9 )
145
149
z ⋅ s
146
150
end
147
- return aₖ .* c + l .* cₗ + h .* cₕ # Rₖ from backward pass
151
+ Rₖ .= aₖ .* c + l .* cₗ + h .* cₕ
152
+ return nothing
148
153
end
0 commit comments