2
2
abstract type AbstractLRPRule end
3
3
4
4
# Generic LRP rule. Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
5
- function lrp! (rule:: R , layer:: L , Rₖ , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule ,L}
6
- lrp_autodiff! (rule, layer, Rₖ , aₖ, Rₖ₊₁)
5
+ function lrp! (Rₖ, rule:: R , layer:: L , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule ,L}
6
+ lrp_autodiff! (Rₖ, rule, layer , aₖ, Rₖ₊₁)
7
7
return nothing
8
8
end
9
9
10
10
function lrp_autodiff! (
11
- rule :: R , layer :: L , Rₖ :: T1 , aₖ:: T1 , Rₖ₊₁:: T2
11
+ Rₖ :: T1 , rule :: R , layer :: L , aₖ:: T1 , Rₖ₊₁:: T2
12
12
) where {R<: AbstractLRPRule ,L,T1,T2}
13
13
layerᵨ = modify_layer (rule, layer)
14
14
c:: T1 = only (
@@ -23,21 +23,21 @@ function lrp_autodiff!(
23
23
end
24
24
25
25
# For linear layer types such as Dense layers, using autodiff is overkill.
26
- function lrp! (rule:: R , layer:: Dense , Rₖ , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
27
- lrp_dense! (rule, layer, Rₖ , aₖ, Rₖ₊₁)
26
+ function lrp! (Rₖ, rule:: R , layer:: Dense , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
27
+ lrp_dense! (Rₖ, rule, layer , aₖ, Rₖ₊₁)
28
28
return nothing
29
29
end
30
30
31
- function lrp_dense! (rule:: R , l, Rₖ , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
31
+ function lrp_dense! (Rₖ, rule:: R , l, aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
32
32
ρW, ρb = modify_params (rule, get_params (l)... )
33
33
ãₖ₊₁ = modify_denominator (rule, ρW * aₖ .+ ρb)
34
34
@tullio Rₖ[j, b] = aₖ[j, b] * ρW[k, j] / ãₖ₊₁[k, b] * Rₖ₊₁[k, b]
35
35
return nothing
36
36
end
37
37
38
38
# Other special cases that are dispatched on layer type:
39
- lrp! (:: AbstractLRPRule , :: DropoutLayer , Rₖ , aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
40
- lrp! (:: AbstractLRPRule , :: ReshapingLayer , Rₖ , aₖ, Rₖ₊₁) = (Rₖ .= reshape (Rₖ₊₁, size (aₖ)))
39
+ lrp! (Rₖ, :: AbstractLRPRule , :: DropoutLayer , aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
40
+ lrp! (Rₖ, :: AbstractLRPRule , :: ReshapingLayer , aₖ, Rₖ₊₁) = (Rₖ .= reshape (Rₖ₊₁, size (aₖ)))
41
41
42
42
# To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
43
43
# If this isn't done, the following fallbacks are used by default:
@@ -75,7 +75,7 @@ Constructor for LRP-0 rule. Commonly used on upper layers.
75
75
struct ZeroRule <: AbstractLRPRule end
76
76
77
77
"""
78
- GammaRule(; γ=0.25)
78
+ GammaRule([ γ=0.25] )
79
79
80
80
Constructor for LRP-``γ`` rule. Commonly used on lower layers.
81
81
@@ -84,16 +84,17 @@ Arguments:
84
84
"""
85
85
struct GammaRule{T} <: AbstractLRPRule
86
86
γ:: T
87
- GammaRule (; γ= 0.25 ) = new {Float32} (γ)
87
+ GammaRule (γ= 0.25f0 ) = new {Float32} (γ)
88
88
end
89
89
function modify_params (r:: GammaRule , W, b)
90
- ρW = W + r. γ * relu .(W)
91
- ρb = b + r. γ * relu .(b)
90
+ T = eltype (W)
91
+ ρW = W + convert (T, r. γ) * relu .(W)
92
+ ρb = b + convert (T, r. γ) * relu .(b)
92
93
return ρW, ρb
93
94
end
94
95
95
96
"""
96
- EpsilonRule(; ϵ=1f-6 )
97
+ EpsilonRule([ϵ=1.0f-6] )
97
98
98
99
Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.
99
100
@@ -102,7 +103,7 @@ Arguments:
102
103
"""
103
104
struct EpsilonRule{T} <: AbstractLRPRule
104
105
ϵ:: T
105
- EpsilonRule (; ϵ= 1.0f-6 ) = new {Float32} (ϵ)
106
+ EpsilonRule (ϵ= 1.0f-6 ) = new {Float32} (ϵ)
106
107
end
107
108
modify_denominator (r:: EpsilonRule , d) = stabilize_denom (d, r. ϵ)
108
109
@@ -122,8 +123,8 @@ struct ZBoxRule{T} <: AbstractLRPRule
122
123
end
123
124
124
125
# The ZBoxRule requires its own implementation of relevance propagation.
125
- lrp! (r:: ZBoxRule , layer:: Dense , Rₖ, aₖ, Rₖ₊₁) = lrp_zbox! (r, layer, Rₖ , aₖ, Rₖ₊₁)
126
- lrp! (r:: ZBoxRule , layer:: Conv , Rₖ, aₖ, Rₖ₊₁) = lrp_zbox! (r, layer, Rₖ , aₖ, Rₖ₊₁)
126
+ lrp! (Rₖ, r:: ZBoxRule , layer:: Dense , aₖ, Rₖ₊₁) = lrp_zbox! (Rₖ, r, layer , aₖ, Rₖ₊₁)
127
+ lrp! (Rₖ, r:: ZBoxRule , layer:: Conv , aₖ, Rₖ₊₁) = lrp_zbox! (Rₖ, r, layer , aₖ, Rₖ₊₁)
127
128
128
129
_zbox_bound (T, c:: Real , in_size) = fill (convert (T, c), in_size)
129
130
function _zbox_bound (T, A:: AbstractArray , in_size)
@@ -135,7 +136,7 @@ function _zbox_bound(T, A::AbstractArray, in_size)
135
136
return convert .(T, A)
136
137
end
137
138
138
- function lrp_zbox! (r :: ZBoxRule , layer :: L , Rₖ :: T1 , aₖ:: T1 , Rₖ₊₁:: T2 ) where {L,T1,T2}
139
+ function lrp_zbox! (Rₖ :: T1 , r :: ZBoxRule , layer :: L , aₖ:: T1 , Rₖ₊₁:: T2 ) where {L,T1,T2}
139
140
T = eltype (aₖ)
140
141
in_size = size (aₖ)
141
142
l = _zbox_bound (T, r. low, in_size)
0 commit comments