@@ -22,30 +22,38 @@ abstract type AbstractLRPRule end
 # This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
 # It can be extended for new rules via `modify_denominator` and `modify_params`.
 # Since it uses autodiff, it is used as a fallback for layer types without a custom implementation.
-(rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁) = lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+function lrp(rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
+    return lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+end

-function lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+function lrp_autodiff(
+    rule::R, layer::L, aₖ::T1, Rₖ₊₁::T2
+) where {R<:AbstractLRPRule,L,T1,T2}
     layerᵨ = _modify_layer(rule, layer)
-    function fwpass(a)
-        z = layerᵨ(a)
-        s = Zygote.dropgrad(Rₖ₊₁ ./ modify_denominator(rule, z))
-        return z ⋅ s
-    end
-    return aₖ .* gradient(fwpass, aₖ)[1] # Rₖ
+    c::T1 = only(
+        gradient(aₖ) do a
+            z::T2 = layerᵨ(a)
+            s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
+            z ⋅ s
+        end,
+    )
+    return aₖ .* c # Rₖ
 end

 # For linear layer types such as Dense layers, using autodiff is overkill.
-(rule::AbstractLRPRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_dense(rule, layer, aₖ, Rₖ₊₁)
+function lrp(rule::R, layer::Dense, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
+    return lrp_dense(rule, layer, aₖ, Rₖ₊₁)
+end

-function lrp_dense(rule, l, aₖ, Rₖ₊₁)
+function lrp_dense(rule::R, l, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
     ρW, ρb = modify_params(rule, get_params(l)...)
     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
     return @tullio Rₖ[j] := aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
 end

 # Other special cases that are dispatched on layer type:
-(::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
-(::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
+lrp(::AbstractLRPRule, ::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
+lrp(::AbstractLRPRule, ::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))

 # To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
 # If this isn't done, the following fallbacks are used by default:
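The `@tullio` expression in `lrp_dense` above is Tullio.jl einsum notation for the generic LRP rule: the relevance Rₖ₊₁[k] is redistributed to input j in proportion to its contribution aₖ[j] * ρW[k, j], normalized by the stabilized denominator ãₖ₊₁[k]. As a minimal sketch (the helper name is hypothetical, not part of this commit), the equivalent explicit loops would be:

    # Illustrative expansion of the @tullio contraction in lrp_dense
    function lrp_dense_loops(aₖ, ρW, ãₖ₊₁, Rₖ₊₁)
        Rₖ = zero(aₖ)
        for k in eachindex(Rₖ₊₁), j in eachindex(aₖ)
            # redistribute Rₖ₊₁[k] proportionally to input j's contribution
            Rₖ[j] += aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
        end
        return Rₖ
    end

Tullio generates (and, where possible, vectorizes) essentially this contraction over the output index k.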
@@ -65,7 +73,7 @@ modify_denominator(::AbstractLRPRule, d) = stabilize_denom(d; eps=1.0f-9) # gene

 # This helper function applies `modify_params`:
 _modify_layer(::AbstractLRPRule, layer) = layer # skip layers without modify_params
-function _modify_layer(rule::AbstractLRPRule, layer::Union{Dense,Conv})
+function _modify_layer(rule::R, layer::L) where {R<:AbstractLRPRule,L<:Union{Dense,Conv}}
     return set_params(layer, modify_params(rule, get_params(layer)...)...)
 end
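To make the extension mechanism concrete: a new rule is a subtype of `AbstractLRPRule` plus, optionally, its own `modify_params` and/or `modify_denominator` methods; `lrp` and `_modify_layer` then pick it up through dispatch with no further changes. A sketch of a hypothetical γ-style rule (the name and the constant 0.25 are illustrative assumptions, not part of this commit):

    struct MyGammaRule <: AbstractLRPRule end

    # Upweight positive weights and biases by a factor γ.
    # modify_denominator is left undefined, so the stabilizing fallback above applies.
    function modify_params(::MyGammaRule, W, b)
        γ = 0.25f0
        return W + γ * max.(0, W), b + γ * max.(0, b)
    end

Every `lrp` method defined in the first hunk, including the autodiff fallback, then works with `MyGammaRule` unchanged.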
@@ -117,26 +125,24 @@ Commonly used on the first layer for pixel input.
 struct ZBoxRule <: AbstractLRPRule end

 # The ZBoxRule requires its own implementation of relevance propagation.
-(rule::ZBoxRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
-(rule::ZBoxRule)(layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+lrp(::ZBoxRule, layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+lrp(::ZBoxRule, layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)

-function lrp_zbox(layer, aₖ, Rₖ₊₁)
+function lrp_zbox(layer::L, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
     W, b = get_params(layer)
     l, h = fill.(extrema(aₖ), (size(aₖ),))

     layer⁺ = set_params(layer, max.(0, W), max.(0, b)) # W⁺, b⁺
     layer⁻ = set_params(layer, min.(0, W), min.(0, b)) # W⁻, b⁻

-    # Forward pass
-    function fwpass(a, l, h)
-        f = layer(a)
-        f⁺ = layer⁺(l)
-        f⁻ = layer⁻(h)
+    c::T1, cₗ::T1, cₕ::T1 = gradient(aₖ, l, h) do a, l, h
+        f::T2 = layer(a)
+        f⁺::T2 = layer⁺(l)
+        f⁻::T2 = layer⁻(h)

         z = f - f⁺ - f⁻
-        s = Zygote.dropgrad(safedivide(Rₖ₊₁, z; eps=1e-9))
-        return z ⋅ s
+        s = Zygote.@ignore safedivide(Rₖ₊₁, z; eps=1e-9)
+        z ⋅ s
     end
-    c, cₗ, cₕ = gradient(fwpass, aₖ, l, h) # w.r.t. three inputs
     return aₖ .* c + l .* cₗ + h .* cₕ # Rₖ from backward pass
 end
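For context, a call through the new `lrp` entry point with `ZBoxRule` might look as follows; the Flux layer, its dimensions, and the random inputs are assumptions for illustration, not part of this commit:

    using Flux

    layer = Dense(784, 100, relu)   # hypothetical first layer for flattened pixel input
    aₖ = rand(Float32, 784)         # activations entering the layer
    Rₖ₊₁ = rand(Float32, 100)       # relevance arriving from the layer above
    Rₖ = lrp(ZBoxRule(), layer, aₖ, Rₖ₊₁)   # same shape as aₖ

Since `lrp_zbox` derives the box bounds l and h from `extrema(aₖ)`, no extra arguments are needed; the three gradients c, cₗ and cₕ are taken in a single `gradient` call over the closure.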