Skip to content

Commit bc06035

Browse files
authored
Implement Base.show for LRP analyzers (#89)
* Implement `Base.show` for LRP analyzers * Add tests * Update "Advanced LRP" docs
1 parent 0d54da7 commit bc06035

File tree

5 files changed

+81
-20
lines changed

5 files changed

+81
-20
lines changed

docs/literate/advanced_lrp.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ x, _ = MNIST(Float32, :test)[10]
2020
input = reshape(x, 28, 28, 1, :);
2121

2222
# ## LRP composites
23-
# ### Custom composites
23+
# ### Assigning individual rules
2424
# When creating an LRP-analyzer, we can assign individual rules to each layer.
25-
# The array of rules has to match the length of the Flux chain:
25+
# The array of rules has to match the length of the Flux chain.
26+
# The `LRP` analyzer will show a summary of how layers and rules got matched:
2627
rules = [
2728
ZBoxRule(0.0f0, 1.0f0),
2829
EpsilonRule(),
@@ -35,6 +36,8 @@ rules = [
3536
]
3637

3738
analyzer = LRP(model, rules)
39+
40+
#
3841
heatmap(input, analyzer)
3942

4043
# Since some Flux Chains contain other Flux Chains, ExplainableAI provides
@@ -44,8 +47,10 @@ heatmap(input, analyzer)
4447
#md # Not all models can be flattened, e.g. those using
4548
#md # `Parallel` and `SkipConnection` layers.
4649

50+
# ### Custom composites
4751
# Instead of manually defining a list of rules, we can also use a [`Composite`](@ref).
48-
# A composite contructs a list of LRP-rules by sequentially applying composite primitives.
52+
# A composite contructs a list of LRP-rules by sequentially applying
53+
# [Composite primitives](@ref composite_primitive_api) it contains.
4954
#
5055
# To obtain the same set of rules as in the previous example, we can define
5156
composite = Composite(
@@ -57,11 +62,12 @@ composite = Composite(
5762
FirstLayerRule(ZBoxRule(0.0f0, 1.0f0)), # apply ZBoxRule on the first layer
5863
)
5964

60-
analyzer = LRP(model, composite) # construct LRP analyzer from composite
61-
heatmap(input, analyzer)
65+
# We now construct an LRP analyzer from `composite`
66+
analyzer = LRP(model, composite)
6267

63-
# This analyzer contains the same rules as our previous one:
64-
analyzer.rules # show rules
68+
# As you can see, this analyzer contains the same rules as our previous one
69+
# and therefore also produces the same heatmaps:
70+
heatmap(input, analyzer)
6571

6672
# ### Composite primitives
6773
# The following [Composite primitives](@ref composite_primitive_api) can used to construct a [`Composite`](@ref).
@@ -87,7 +93,9 @@ analyzer.rules # show rules
8793
# ### Default composites
8894
# A list of implemented default composites can be found under
8995
# [Default composites](@ref default_composite_api) in the API reference, e.g. [`EpsilonPlusFlat`](@ref):
90-
EpsilonPlusFlat()
96+
composite = EpsilonPlusFlat()
97+
#
98+
analyzer = LRP(model, composite)
9199

92100
# ## Custom LRP rules
93101
# Let's define a rule that modifies the weights and biases of our layer on the forward pass.
@@ -101,7 +109,7 @@ struct MyGammaRule <: AbstractLRPRule end
101109
import ExplainableAI: modify_param!
102110

103111
function modify_param!(::MyGammaRule, param)
104-
param .+= 0.25 * relu.(param)
112+
param .+= 0.25f0 * relu.(param)
105113
return nothing
106114
end
107115

src/lrp/show.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,33 @@
1-
# Composites
1+
const COLOR_COMMENT = :light_black
2+
const COLOR_RULE = :yellow
3+
const COLOR_TYPE = :blue
4+
const COLOR_RANGE = :green
5+
6+
typename(x) = string(nameof(typeof(x)))
7+
8+
################
9+
# LRP analyzer #
10+
################
11+
12+
_print_layer(io::IO, l) = string(sprint(show, l; context=io))
13+
function Base.show(io::IO, m::MIME"text/plain", analyzer::LRP)
14+
layer_names = [_print_layer(io, l) for l in analyzer.model]
15+
npad = maximum(length.(layer_names)) + 1 # padding to align rules with rpad
16+
17+
println(io, "LRP", "(")
18+
for (l, r) in zip(layer_names, analyzer.rules)
19+
print(io, " ", rpad(l, npad), " => ")
20+
printstyled(io, r; color=COLOR_RULE)
21+
println(io, ",")
22+
end
23+
println(io, ")")
24+
return nothing
25+
end
26+
27+
#############
28+
# Composite #
29+
#############
30+
231
_range_string(r::LayerRule) = "layer $(r.n)"
332
_range_string(::GlobalRule) = "all layers"
433
_range_string(r::RangeRule) = "layers $(r.range)"
@@ -11,13 +40,6 @@ _range_string(::LastLayerTypeRule) = "last layer"
1140
_range_string(r::FirstNTypeRule) = "layers $(1:r.n)"
1241
_range_string(r::LastNTypeRule) = "last $(r.n) layers"
1342

14-
const COLOR_COMMENT = :light_black
15-
const COLOR_RULE = :yellow
16-
const COLOR_TYPE = :blue
17-
const COLOR_RANGE = :green
18-
19-
typename(x) = string(nameof(typeof(x)))
20-
2143
function Base.show(io::IO, m::MIME"text/plain", c::Composite)
2244
println(io, "Composite", "(")
2345
for p in c.primitives

test/references/show/lrp1.txt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
LRP(
2+
Conv((3, 3), 3 => 64, relu, pad=1) => ZBoxRule{Float32}(-3.0f0, 3.0f0),
3+
MaxPool((2, 2)) => EpsilonRule{Float32}(1.0f-6),
4+
Conv((3, 3), 64 => 128, relu, pad=1) => FlatRule(),
5+
MaxPool((2, 2)) => EpsilonRule{Float32}(1.0f-5),
6+
Conv((3, 3), 128 => 256, relu, pad=1) => FlatRule(),
7+
Conv((3, 3), 256 => 256, relu, pad=1) => FlatRule(),
8+
MaxPool((2, 2)) => EpsilonRule{Float32}(1.0f-5),
9+
Conv((3, 3), 256 => 512, relu, pad=1) => AlphaBetaRule{Float32}(2.0f0, 1.0f0),
10+
Conv((3, 3), 512 => 512, relu, pad=1) => AlphaBetaRule{Float32}(1.0f0, 0.0f0),
11+
MaxPool((2, 2)) => EpsilonRule{Float32}(1.0f-5),
12+
Conv((3, 3), 512 => 512, relu, pad=1) => AlphaBetaRule{Float32}(2.0f0, 1.0f0),
13+
Conv((3, 3), 512 => 512, relu, pad=1) => AlphaBetaRule{Float32}(2.0f0, 1.0f0),
14+
MaxPool((2, 2)) => EpsilonRule{Float32}(1.0f-6),
15+
Flux.flatten => PassRule(),
16+
Dense(25088 => 4096, relu) => EpsilonRule{Float32}(1.0f-6),
17+
Dropout(0.5) => PassRule(),
18+
Dense(4096 => 4096, relu) => EpsilonRule{Float32}(1.0f-7),
19+
Dropout(0.5) => ZeroRule(),
20+
Dense(4096 => 1000) => PassRule(),
21+
)

test/references/show/lrp2.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
LRP(
2+
Conv((5, 5), 1 => 6, relu) => AlphaBetaRule{Float32}(2.0f0, 1.0f0),
3+
MaxPool((2, 2)) => ZeroRule(),
4+
Conv((5, 5), 6 => 16, relu) => ZeroRule(),
5+
MaxPool((2, 2)) => ZeroRule(),
6+
Flux.flatten => ZeroRule(),
7+
Dense(256 => 120, relu) => ZeroRule(),
8+
Dense(120 => 84, relu) => ZeroRule(),
9+
Dense(84 => 10) => EpsilonRule{Float32}(2.0f-5),
10+
)

test/test_composite.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ for (name, c) in DEFAULT_COMPOSITES
2020
@test_reference "references/show/$name.txt" repr("text/plain", c)
2121
end
2222

23-
# Test printing
24-
25-
# This composite is non-sensical, but covers as many composite primitives as possible
23+
# This composite is non-sensical, but covers many composite primitives
2624
composite1 = Composite(
2725
ZeroRule(), # default rule
2826
GlobalRule(PassRule()), # override default rule
@@ -62,6 +60,7 @@ analyzer1 = LRP(model, composite1)
6260
ZeroRule()
6361
PassRule()
6462
]
63+
@test_reference "references/show/lrp1.txt" repr("text/plain", analyzer1)
6564
@test_reference "references/show/composite1.txt" repr("text/plain", composite1)
6665

6766
model = Chain(
@@ -91,4 +90,5 @@ analyzer2 = LRP(model, composite2)
9190
ZeroRule()
9291
EpsilonRule(2.0f-5)
9392
]
93+
@test_reference "references/show/lrp2.txt" repr("text/plain", analyzer2)
9494
@test_reference "references/show/composite2.txt" repr("text/plain", composite2)

0 commit comments

Comments
 (0)