
Commit df4d73f

Add PassRule (#76)
* Add `PassRule`
* Clean up rule doc strings and add references
1 parent 840eee3 commit df4d73f

13 files changed: +80 -49 lines

README.md

Lines changed: 5 additions & 3 deletions
@@ -57,9 +57,11 @@ Currently, the following analyzers are implemented:
 ├── SmoothGrad
 ├── IntegratedGradients
 └── LRP
-    ├── LRPZero
-    ├── LRPEpsilon
-    └── LRPGamma
+    ├── ZeroRule
+    ├── EpsilonRule
+    ├── GammaRule
+    ├── PassRule
+    └── ZBoxRule
 ```
 
 One of the design goals of ExplainableAI.jl is extensibility.
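Editor's sketch, not part of the diff: with the renamed rules, assigning one rule per layer might look as follows. The model is hypothetical, and the `LRP(model, rules)` per-layer constructor and `analyze(input, analyzer)` call are assumed from the package documentation.

```julia
using ExplainableAI, Flux

# Hypothetical model: a reshaping layer followed by two Dense layers.
model = Chain(Flux.flatten, Dense(32, 16, relu), Dense(16, 10))

# One rule per layer: PassRule for the reshaping layer,
# GammaRule for the lower Dense layer, ZeroRule for the output layer.
rules    = [PassRule(), GammaRule(), ZeroRule()]
analyzer = LRP(model, rules)                              # assumed per-layer constructor
expl     = analyze(rand(Float32, 4, 4, 2, 1), analyzer)   # assumed analyze API
```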

docs/src/api.md

Lines changed: 2 additions & 1 deletion
@@ -24,8 +24,9 @@ InterpolationAugmentation
 ## Rules
 ```@docs
 ZeroRule
-GammaRule
 EpsilonRule
+GammaRule
+PassRule
 ZBoxRule
 ```

src/ExplainableAI.jl

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ export LRP
 # LRP rules
 export AbstractLRPRule
 export LRP_CONFIG
-export ZeroRule, EpsilonRule, GammaRule, ZBoxRule
+export ZeroRule, EpsilonRule, GammaRule, ZBoxRule, PassRule
 export modify_input, modify_denominator
 export modify_param!, modify_layer!
 export check_model

src/lrp_rules.jl

Lines changed: 51 additions & 8 deletions
@@ -72,8 +72,8 @@ Inplace-modify parameters before computing the relevance.
 @inline modify_param!(rule, param) = nothing # general fallback
 
 # Useful presets:
-modify_param!(::Val{:mask_positive}, p) = (p .= max.(zero(eltype(p)), p), return nothing)
-modify_param!(::Val{:mask_negative}, p) = (p .= min.(zero(eltype(p)), p), return nothing)
+modify_param!(::Val{:mask_positive}, p) = p .= max.(zero(eltype(p)), p)
+modify_param!(::Val{:mask_negative}, p) = p .= min.(zero(eltype(p)), p)
 
 # Internal wrapper functions for bias-free layers.
 @inline modify_bias!(rule::R, b) where {R} = modify_param!(rule, b)
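Editor's note, not part of the diff: `modify_param!` mutates its argument in place, which is why the explicit `return nothing` could be dropped above. A minimal sketch of the `:mask_positive` preset on a hypothetical weight matrix:

```julia
using ExplainableAI

W = Float32[-1.0 2.0; 3.0 -4.0]          # hypothetical weight matrix
modify_param!(Val(:mask_positive), W)    # exported; clamps negative entries to zero in place
# W is now Float32[0.0 2.0; 3.0 0.0]
```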
@@ -101,18 +101,29 @@ end
 """
     ZeroRule()
 
-Constructor for LRP-0 rule. Commonly used on upper layers.
+LRP-0 rule. Commonly used on upper layers.
+
+# References
+[1]: S. Bach et al., On Pixel-Wise Explanations for Non-Linear Classifier Decisions by
+Layer-Wise Relevance Propagation
 """
 struct ZeroRule <: AbstractLRPRule end
 @inline check_compat(::ZeroRule, layer) = nothing
 
+# Optimization to save allocations since weights don't need to be reset:
+get_layer_resetter(::ZeroRule, layer) = Returns(nothing)
+
 """
     EpsilonRule([ϵ=1.0f-6])
 
-Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.
+LRP-``ϵ`` rule. Commonly used on middle layers.
 
 Arguments:
 - `ϵ`: Optional stabilization parameter, defaults to `1f-6`.
+
+# References
+[1]: S. Bach et al., On Pixel-Wise Explanations for Non-Linear Classifier Decisions by
+Layer-Wise Relevance Propagation
 """
 struct EpsilonRule{T} <: AbstractLRPRule
     ϵ::T
@@ -121,13 +132,19 @@ end
 modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
 @inline check_compat(::EpsilonRule, layer) = nothing
 
+# Optimization to save allocations since weights don't need to be reset:
+get_layer_resetter(::EpsilonRule, layer) = Returns(nothing)
+
 """
     GammaRule([γ=0.25])
 
-Constructor for LRP-``γ`` rule. Commonly used on lower layers.
+LRP-``γ`` rule. Commonly used on lower layers.
 
 Arguments:
-- `γ`: Optional multiplier for added positive weights, defaults to 0.25.
+- `γ`: Optional multiplier for added positive weights, defaults to `0.25`.
+
+# References
+[1]: G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
 """
 struct GammaRule{T} <: AbstractLRPRule
     γ::T
@@ -140,15 +157,34 @@ function modify_param!(r::GammaRule, param::AbstractArray{T}) where {T}
 end
 @inline check_compat(rule::GammaRule, layer) = require_weight_and_bias(rule, layer)
 
+"""
+    PassRule()
+
+Pass-through rule. Passes relevance through to the lower layer.
+Supports reshaping layers.
+"""
+struct PassRule <: AbstractLRPRule end
+function lrp!(Rₖ, ::PassRule, layer, aₖ, Rₖ₊₁)
+    if size(aₖ) == size(Rₖ₊₁)
+        Rₖ .= Rₖ₊₁
+    end
+    Rₖ .= reshape(Rₖ₊₁, size(aₖ))
+    return nothing
+end
+# No extra checks as reshaping operation will throw an error if layer isn't compatible:
+@inline check_compat(::PassRule, layer) = nothing
+
 """
     ZBoxRule(low, high)
 
-Constructor for LRP-``z^{\\mathcal{B}}``-rule.
-Commonly used on the first layer for pixel input.
+LRP-``z^{\\mathcal{B}}``-rule. Commonly used on the first layer for pixel input.
 
 The parameters `low` and `high` should be set to the lower and upper bounds of the input features,
 e.g. `0.0` and `1.0` for raw image data.
 It is also possible to provide two arrays that match the input size.
+
+## References
+[1]: G. Montavon et al., Explaining nonlinear classification decisions with deep Taylor decomposition
 """
 struct ZBoxRule{T} <: AbstractLRPRule
     low::T
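Editor's sketch, not part of the commit: how the new `PassRule` reroutes relevance through a reshaping layer such as `Flux.flatten`. The shapes are hypothetical; `lrp!` is internal (not exported), so it is called through the module here.

```julia
using ExplainableAI, Flux

aₖ   = rand(Float32, 4, 4, 2, 1)   # activation entering Flux.flatten (32 values)
Rₖ₊₁ = rand(Float32, 32, 1)        # relevance coming back from the layer above
Rₖ   = similar(aₖ)                 # relevance to be written for the lower layer

ExplainableAI.lrp!(Rₖ, PassRule(), Flux.flatten, aₖ, Rₖ₊₁)
size(Rₖ) == size(aₖ)               # true: relevance is reshaped back to the activation shape
```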
@@ -194,6 +230,13 @@ for R in (ZeroRule, EpsilonRule)
     @eval lrp!(Rₖ, ::$R, ::ReshapingLayer, aₖ, Rₖ₊₁) = (Rₖ .= reshape(Rₖ₊₁, size(aₖ)))
 end
 
+# Special cases for rules that don't modify params for extra performance:
+for R in (ZeroRule, EpsilonRule)
+    for L in (DropoutLayer, ReshapingLayer)
+        @eval lrp!(Rₖ, ::$R, l::$L, aₖ, Rₖ₊₁) = lrp!(Rₖ, PassRule(), l, aₖ, Rₖ₊₁)
+    end
+end
+
 # Fast implementation for Dense layer using Tullio.jl's einsum notation:
 for R in (ZeroRule, EpsilonRule, GammaRule)
     @eval function lrp!(Rₖ, rule::$R, layer::Dense, aₖ, Rₖ₊₁)
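Editor's note: the nested `@eval` loop added above generates one forwarding method per rule and layer type, equivalent to writing the following inside `src/lrp_rules.jl` by hand:

```julia
# Equivalent hand-written expansion of the nested @eval loop (sketch):
lrp!(Rₖ, ::ZeroRule,    l::DropoutLayer,   aₖ, Rₖ₊₁) = lrp!(Rₖ, PassRule(), l, aₖ, Rₖ₊₁)
lrp!(Rₖ, ::ZeroRule,    l::ReshapingLayer, aₖ, Rₖ₊₁) = lrp!(Rₖ, PassRule(), l, aₖ, Rₖ₊₁)
lrp!(Rₖ, ::EpsilonRule, l::DropoutLayer,   aₖ, Rₖ₊₁) = lrp!(Rₖ, PassRule(), l, aₖ, Rₖ₊₁)
lrp!(Rₖ, ::EpsilonRule, l::ReshapingLayer, aₖ, Rₖ₊₁) = lrp!(Rₖ, PassRule(), l, aₖ, Rₖ₊₁)
```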

src/types.jl

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ const ConvLayer = Union{Conv} # TODO: DepthwiseConv, ConvTranspose, CrossCor
 const DropoutLayer = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
 
 """Union type for reshaping layers such as `flatten`."""
-const ReshapingLayer = Union{typeof(Flux.flatten), typeof(Flux.MLUtils.flatten)}
+const ReshapingLayer = Union{typeof(Flux.flatten),typeof(Flux.MLUtils.flatten)}
 
 """Union type for max pooling layers."""
 const MaxPoolLayer = Union{MaxPool,AdaptiveMaxPool,GlobalMaxPool}

test/references/heatmaps/vgg11_LRPEpsilon.txt

Lines changed: 0 additions & 15 deletions
This file was deleted.
Four binary files (-589 KB each; file names not captured in this view) were also changed; binary files are not shown.
