abstract type AbstractLRPRule end
# This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
- # It can be extended for new rules via `modify_denominator` and `modify_layer`,
- # which in turn uses `modify_params`.
+ # It can be extended for new rules via `modify_denominator` and `modify_params`.
+ # Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
-     layerᵨ = modify_layer(rule, layer)
+     layerᵨ = _modify_layer(rule, layer)
    function fwpass(a)
        z = layerᵨ(a)
        s = Zygote.dropgrad(Rₖ₊₁ ./ modify_denominator(rule, z))
        return z ⋅ s
    end
-     c = gradient(fwpass, aₖ)[1]
-     Rₖ = aₖ .* c
-     return Rₖ
+     return aₖ .* gradient(fwpass, aₖ)[1] # Rₖ
end
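
# As a quick sanity check of the gradient trick above (a sketch, assuming only
# Zygote and LinearAlgebra; the toy dense layer and shapes are hypothetical):
# for z(a) = W * a .+ b, the gradient of a -> z(a) ⋅ s at constant s is W' * s,
# so the autodiff formulation reduces to the textbook LRP-0 propagation
# Rₖ = aₖ .* (W' * (Rₖ₊₁ ./ zₖ)).
using Zygote, LinearAlgebra

W, b = randn(Float32, 2, 3), randn(Float32, 2)
dense(a) = W * a .+ b                       # stand-in for a Dense layer
aₖ, Rₖ₊₁ = rand(Float32, 3), rand(Float32, 2)
zₖ = dense(aₖ)

fwpass(a) = dense(a) ⋅ Zygote.dropgrad(Rₖ₊₁ ./ zₖ)
R_autodiff = aₖ .* gradient(fwpass, aₖ)[1]  # the rule as implemented above
R_explicit = aₖ .* (W' * (Rₖ₊₁ ./ zₖ))      # explicit matrix form
@assert R_autodiff ≈ R_explicit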
# Special cases are dispatched on layer type:
- (rule::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
- (rule::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
+ (::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
+ (::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))

+ # To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
+ # If this isn't done, the following fallbacks are used by default:
"""
-     modify_layer(rule, layer)
-
- Applies `modify_params` to layer if it has parameters
- """
- modify_layer(::AbstractLRPRule, l) = l # skip layers without params
- function modify_layer(rule::AbstractLRPRule, l::Union{Dense,Conv})
-     W, b = get_weights(l)
-     ρW, ρb = modify_params(rule, W, b)
-     return set_weights(l, ρW, ρb)
- end
-
- """
-     modify_params!(rule, W, b)
+     modify_params(rule, W, b)

Function that modifies weights and biases before applying relevance propagation.
"""
modify_params(::AbstractLRPRule, W, b) = (W, b) # general fallback

"""
-     modify_denominator!(d, rule)
+     modify_denominator(rule, d)

Function that modifies zₖ on the forward pass, e.g. for numerical stability.
"""
modify_denominator(::AbstractLRPRule, d) = stabilize_denom(d; eps=1.0f-9) # general fallback

+ # This helper function applies `modify_params`:
+ _modify_layer(::AbstractLRPRule, layer) = layer # skip layers without modify_params
+ function _modify_layer(rule::AbstractLRPRule, layer::Union{Dense,Conv})
+     return set_params(layer, modify_params(rule, get_params(layer)...)...)
+ end
+
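
# For illustration, the γ- and ϵ-type rules mentioned above could be sketched by
# overloading exactly these two functions (the names `MyGammaRule`/`MyEpsilonRule`
# and the constants 0.25 and 0.1 are hypothetical, chosen for the example):
struct MyGammaRule <: AbstractLRPRule end
function modify_params(::MyGammaRule, W, b)
    γ = 0.25f0
    return (W + γ * max.(0, W), b + γ * max.(0, b))  # upweight positive weights
end

struct MyEpsilonRule <: AbstractLRPRule end
modify_denominator(::MyEpsilonRule, d) = stabilize_denom(d; eps=0.1f0)  # stronger stabilization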
"""
    ZeroRule()
@@ -111,11 +105,11 @@ struct ZBoxRule <: AbstractLRPRule end
# The ZBoxRule requires its own implementation of relevance propagation.
function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
-     layer, layer⁺, layer⁻ = modify_layer(rule, layer)
+     W, b = get_params(layer)
+     l, h = fill.(extrema(aₖ), (size(aₖ),))

-     onemat = ones(eltype(aₖ), size(aₖ))
-     l = onemat * minimum(aₖ)
-     h = onemat * maximum(aₖ)
+     layer⁺ = set_params(layer, max.(0, W), max.(0, b)) # W⁺, b⁺
+     layer⁻ = set_params(layer, min.(0, W), min.(0, b)) # W⁻, b⁻

    # Forward pass
    function fwpass(a, l, h)
@@ -128,20 +122,5 @@ function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
        return z ⋅ s
    end
    c, cₗ, cₕ = gradient(fwpass, aₖ, l, h) # w.r.t. three inputs
-
-     # Backward pass
-     Rₖ = aₖ .* c + l .* cₗ + h .* cₕ
-     return Rₖ
- end
-
- function modify_layer(::ZBoxRule, l::Union{Dense,Conv})
-     W, b = get_weights(l)
-     W⁻ = min.(0, W)
-     W⁺ = max.(0, W)
-     b⁻ = min.(0, b)
-     b⁺ = max.(0, b)
-
-     l⁺ = set_weights(l, W⁺, b⁺)
-     l⁻ = set_weights(l, W⁻, b⁻)
-     return l, l⁺, l⁻
+     return aₖ .* c + l .* cₗ + h .* cₕ # Rₖ from backward pass
end
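
# Usage sketch (toy shapes, assuming Flux is loaded and the package's
# `get_params`/`set_params` are defined for the layer): a rule is a callable
# taking (layer, aₖ, Rₖ₊₁) and returning the lower-layer relevance Rₖ:
layer = Dense(4, 2, relu)
aₖ, Rₖ₊₁ = rand(Float32, 4), rand(Float32, 2)
Rₖ = ZBoxRule()(layer, aₖ, Rₖ₊₁)
@assert size(Rₖ) == size(aₖ)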