Skip to content

Commit d384404

Browse files
authored
Add LRP ZPlusRule (#88)
1 parent 509b8b5 commit d384404

File tree

13 files changed

+49
-5
lines changed

13 files changed

+49
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ Currently, the following analyzers are implemented:
8484
├── WSquareRule
8585
├── FlatRule
8686
├── ZBoxRule
87+
├── ZPlusRule
8788
├── AlphaBetaRule
8889
└── PassRule
8990
```

docs/literate/example.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ mosaic(heatmap(batch, analyzer, 1); nrow=10)
117117
# ├── WSquareRule
118118
# ├── FlatRule
119119
# ├── ZBoxRule
120+
# ├── ZPlusRule
120121
# ├── AlphaBetaRule
121122
# └── PassRule
122123
# ```

src/ExplainableAI.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ export LRP
4747
export AbstractLRPRule
4848
export LRP_CONFIG
4949
export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule, PassRule
50-
export ZBoxRule, AlphaBetaRule
50+
export ZBoxRule, ZPlusRule, AlphaBetaRule
5151
export modify_input, modify_denominator
5252
export modify_param!, modify_layer!
5353
export check_model

src/lrp/composite_presets.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ $(repr("text/plain", EpsilonPlus()))
3131
function EpsilonPlus(; epsilon=1.0f-6)
3232
return Composite(
3333
GlobalTypeRule(
34-
ConvLayer => AlphaBetaRule(1.0f0, 0.0f0), # TODO: replace with ZPlusRule
34+
ConvLayer => ZPlusRule(),
3535
Dense => EpsilonRule(epsilon),
3636
DropoutLayer => PassRule(),
3737
ReshapingLayer => PassRule(),
@@ -71,7 +71,7 @@ $(repr("text/plain", EpsilonPlusFlat()))
7171
function EpsilonPlusFlat(; epsilon=1.0f-6)
7272
return Composite(
7373
GlobalTypeRule(
74-
ConvLayer => AlphaBetaRule(1.0f0, 0.0f0), # TODO: replace with ZPlusRule
74+
ConvLayer => ZPlusRule(),
7575
Dense => EpsilonRule(epsilon),
7676
DropoutLayer => PassRule(),
7777
ReshapingLayer => PassRule(),

src/lrp/rules.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,38 @@ function lrp!(Rₖ, rule::AlphaBetaRule, layer::L, aₖ, Rₖ₊₁) where {L}
311311
return nothing
312312
end
313313

314+
"""
315+
ZPlusRule()
316+
317+
LRP-``z^{+}`` rule. Commonly used on lower layers.
318+
319+
Equivalent to `AlphaBetaRule(1.0f0, 0.0f0)`, but slightly faster.
320+
See also [`AlphaBetaRule`](@ref).
321+
322+
# References
323+
- [1] $REF_BACH_LRP
324+
- [2] $REF_MONTAVON_DTD
325+
"""
326+
struct ZPlusRule <: AbstractLRPRule end
327+
function lrp!(Rₖ, rule::ZPlusRule, layer::L, aₖ, Rₖ₊₁) where {L}
328+
require_weight_and_bias(rule, layer)
329+
reset! = get_layer_resetter(rule, layer)
330+
331+
aₖ⁺ = keep_positive(aₖ)
332+
aₖ⁻ = keep_negative(aₖ)
333+
334+
modify_layer!(Val(:keep_positive), layer)
335+
out_1, pullback_1 = Zygote.pullback(layer, aₖ⁺)
336+
reset!()
337+
modify_layer!(Val(:keep_negative_zero_bias), layer)
338+
out_2, pullback_2 = Zygote.pullback(layer, aₖ⁻)
339+
reset!()
340+
341+
y_α = Rₖ₊₁ ./ modify_denominator(rule, out_1 + out_2)
342+
Rₖ .= aₖ⁺ .* only(pullback_1(y_α)) + aₖ⁻ .* only(pullback_2(y_α))
343+
return nothing
344+
end
345+
314346
# Special cases for rules that don't modify params for extra performance:
315347
for R in (ZeroRule, EpsilonRule)
316348
for L in (DropoutLayer, ReshapingLayer)
1.34 KB
Binary file not shown.
1.34 KB
Binary file not shown.
1.34 KB
Binary file not shown.
941 Bytes
Binary file not shown.
941 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)