Commit 0963afa: Add AlphaBetaRule (#78)
Parent: 1d0fbf6
11 files changed: 65 additions, 13 deletions

README.md

Lines changed: 2 additions & 1 deletion

````diff
@@ -63,11 +63,12 @@ Currently, the following analyzers are implemented:
 ├── WSquareRule
 ├── FlatRule
 ├── ZBoxRule
+├── AlphaBetaRule
 └── PassRule
 ```

 One of the design goals of ExplainableAI.jl is extensibility.
-Individual LRP rules like `ZeroRule`, `EpsilonRule`, `GammaRule` and `ZBoxRule` [can be composed][docs-composites] and are easily extended by [custom rules][docs-custom-rules].
+Individual LRP rules [can be composed][docs-composites] and are easily extended by [custom rules][docs-custom-rules].

 ## Roadmap
 In the future, we would like to include:
````
benchmark/benchmarks.jl

Lines changed: 4 additions & 11 deletions

````diff
@@ -13,20 +13,11 @@ model = flatten_model(strip_softmax(vgg11.layers))
 T = Float32
 img = rand(MersenneTwister(123), T, (224, 224, 3, 1))

-# Benchmark custom LRP composite
-function LRPCustom(model::Chain)
-    return LRP(
-        model,
-        [ZBoxRule(zero(T), oneunit(T)), repeat([GammaRule()], length(model.layers) - 1)...],
-    )
-end
-
 # Use one representative algorithm of each type
 algs = Dict(
     "Gradient" => Gradient,
     "InputTimesGradient" => InputTimesGradient,
-    "LRPZero" => LRP,
-    "LRPCustom" => LRPCustom, #modifies weights
+    "LRP" => LRP,
     "SmoothGrad" => model -> SmoothGrad(model, 10),
     "IntegratedGradients" => model -> IntegratedGradients(model, 10),
 )
@@ -53,7 +44,6 @@ out_dense = 100
 aₖ = randn(T, insize)

 layers = Dict(
-    "MaxPool" => (MaxPool((3, 3); pad=0), aₖ),
     "Conv" => (Conv((3, 3), 3 => 2), aₖ),
     "Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
 )
@@ -62,6 +52,9 @@ rules = Dict(
     "EpsilonRule" => EpsilonRule(),
     "GammaRule" => GammaRule(),
     "ZBoxRule" => ZBoxRule(zero(T), oneunit(T)),
+    "FlatRule" => FlatRule(),
+    "WSquareRule" => WSquareRule(),
+    "AlphaBetaRule" => AlphaBetaRule(),
 )

 SUITE["Layer"] = BenchmarkGroup([k for k in keys(layers)])
````
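
For readers who want to time the new rule in isolation, here is a hedged sketch along the lines of the per-layer benchmarks above. It calls `lrp!` with the signature added in `src/lrp_rules.jl`; note that `lrp!` is internal (it is not in the export list), so it is accessed through the module, and the layer sizes are arbitrary examples.

```julia
using BenchmarkTools
using Flux
using ExplainableAI

T = Float32
layer = Dense(64, 10, relu)   # example layer with weight and bias, as AlphaBetaRule requires
aₖ    = randn(T, 64, 1)       # input activations
Rₖ₊₁  = randn(T, 10, 1)       # relevance coming from the layer above
Rₖ    = similar(aₖ)           # relevance to be written in place
rule  = AlphaBetaRule()       # defaults: alpha = 2, beta = 1

# lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁) is the internal propagation call benchmarked per layer.
@btime ExplainableAI.lrp!($Rₖ, $rule, $layer, $aₖ, $Rₖ₊₁)
```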

docs/src/api.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -27,6 +27,7 @@ ZeroRule
 EpsilonRule
 GammaRule
 WSquareRule
+AlphaBetaRule
 FlatRule
 ZBoxRule
 PassRule
````

src/ExplainableAI.jl

Lines changed: 2 additions & 1 deletion

````diff
@@ -42,7 +42,8 @@ export LRP
 # LRP rules
 export AbstractLRPRule
 export LRP_CONFIG
-export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule, ZBoxRule, PassRule
+export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule, PassRule
+export ZBoxRule, AlphaBetaRule
 export modify_input, modify_denominator
 export modify_param!, modify_layer!
 export check_model
````

src/lrp_rules.jl

Lines changed: 55 additions & 0 deletions

````diff
@@ -251,7 +251,62 @@ function zbox_input(in::AbstractArray{T}, A::AbstractArray) where {T}
     return convert.(T, A)
 end

+"""
+    AlphaBetaRule(alpha, beta)
+    AlphaBetaRule([alpha=2.0], [beta=1.0])
+
+LRP-``\alpha\beta`` rule. Weights positive and negative contributions according to the
+parameters `alpha` and `beta` respectively. The difference `alpha - beta` must equal one.
+Commonly used on lower layers.
+
+Arguments:
+- `alpha`: Multiplier for the positive output term, defaults to `2.0`.
+- `beta`: Multiplier for the negative output term, defaults to `1.0`.
+
+# References
+[1]: S. Bach et al., On Pixel-Wise Explanations for Non-Linear Classifier Decisions by
+    Layer-Wise Relevance Propagation
+[2]: G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
+"""
+struct AlphaBetaRule{T} <: AbstractLRPRule
+    α::T
+    β::T
+    function AlphaBetaRule(alpha=2.0f0, beta=1.0f0)
+        alpha < 0 && throw(ArgumentError("Parameter `alpha` must be ≥0."))
+        beta < 0 && throw(ArgumentError("Parameter `beta` must be ≥0."))
+        !isone(alpha - beta) && throw(ArgumentError("`alpha - beta` must equal one."))
+        return new{Float32}(alpha, beta)
+    end
+end
+
+# The AlphaBetaRule requires its own implementation of relevance propagation.
+function lrp!(Rₖ, rule::AlphaBetaRule, layer::L, aₖ, Rₖ₊₁) where {L}
+    require_weight_and_bias(rule, layer)
+    reset! = get_layer_resetter(rule, layer)
+
+    aₖ⁺ = keep_positive(aₖ)
+    aₖ⁻ = keep_negative(aₖ)
+
+    modify_layer!(Val(:keep_positive), layer)
+    out_1, pullback_1 = Zygote.pullback(layer, aₖ⁺)
+    reset!()
+    modify_layer!(Val(:keep_negative_zero_bias), layer)
+    out_2, pullback_2 = Zygote.pullback(layer, aₖ⁻)
+    reset!()
+    modify_layer!(Val(:keep_negative), layer)
+    out_3, pullback_3 = Zygote.pullback(layer, aₖ⁺)
+    reset!()
+    modify_layer!(Val(:keep_positive_zero_bias), layer)
+    out_4, pullback_4 = Zygote.pullback(layer, aₖ⁻)
+    reset!()

+    y_α = Rₖ₊₁ ./ modify_denominator(rule, out_1 + out_2)
+    y_β = Rₖ₊₁ ./ modify_denominator(rule, out_3 + out_4)
+    Rₖ .=
+        rule.α .* (aₖ⁺ .* only(pullback_1(y_α)) + aₖ⁻ .* only(pullback_2(y_α))) .-
+        rule.β .* (aₖ⁺ .* only(pullback_3(y_β)) + aₖ⁻ .* only(pullback_4(y_β)))
+    return nothing
+end

 # Special cases for rules that don't modify params for extra performance:
 for R in (ZeroRule, EpsilonRule)
````
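
For context, the propagation rule this `lrp!` method implements is the LRP-αβ rule from the two references cited in the docstring. In the notation of Montavon et al. [2], with the bias treated as an extra input (a₀ = 1, w₀ₖ = bₖ, so the sums in the denominators include the bias term), it reads:

```latex
% LRP-alpha-beta rule; (x)^+ = max(0, x), (x)^- = min(0, x)
R_j = \sum_k \left(
        \alpha \, \frac{(a_j w_{jk})^+}{\sum_{0,j} (a_j w_{jk})^+}
      - \beta  \, \frac{(a_j w_{jk})^-}{\sum_{0,j} (a_j w_{jk})^-}
      \right) R_k
```

The four pullbacks split these sums by sign: `out_1 + out_2` collects the same-sign products (positive weights with positive activations, negative weights with negative activations) and forms the positive denominator, while `out_3 + out_4` collects the mixed-sign products for the negative denominator. The default `AlphaBetaRule()` is therefore LRP-α₂β₁, `AlphaBetaRule(1.0f0, 0.0f0)` gives LRP-α₁β₀, and a call like `AlphaBetaRule(1.0, 1.0)` is rejected by the inner constructor because `alpha - beta` must equal one.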
Binary files changed (not shown): 3 files of 1.34 KB, 2 files of 941 Bytes.
