Commit 367b59b

Add LRP composites (#84)
* Add Composite and primitives
* Add pretty-printing of composites
* Add tests and docstrings
* Add "Advanced LRP" docs example
* Update simple example
1 parent 5afd260 commit 367b59b

File tree

11 files changed: +459 −18 lines

docs/literate/advanced_lrp.jl

Lines changed: 42 additions & 4 deletions
@@ -5,7 +5,7 @@
 #
 # This example will show you how to implement custom LRP rules and register custom layers
 # and activation functions.
-# For this purpose, we will quickly load our model from the previous section:
+# For this purpose, we will quickly load our model from the previous section
 using ExplainableAI
 using Flux
 using MLDatasets
@@ -14,14 +14,14 @@ using BSON
 
 model = BSON.load("../model.bson", @__MODULE__)[:model]
 
+# and data from the MNIST dataset
 index = 10
 x, _ = MNIST(Float32, :test)[10]
 input = reshape(x, 28, 28, 1, :);
 
 # ## Custom LRP composites
-# Instead of creating an LRP-analyzer from a single rule (e.g. `LRP(model, GammaRule())`),
-# we can also assign rules to each layer individually.
-# For this purpose, we create an array of rules that matches the length of the Flux chain:
+# When creating an LRP-analyzer, we can assign individual rules to each layer.
+# The array of rules has to match the length of the Flux chain:
 rules = [
     ZBoxRule(0.0f0, 1.0f0),
     EpsilonRule(),
@@ -43,6 +43,44 @@ heatmap(input, analyzer)
 #md # Not all models can be flattened, e.g. those using
 #md # `Parallel` and `SkipConnection` layers.
 
+# Instead of manually defining a list of rules, we can also use a [`Composite`](@ref).
+# A composite constructs a list of LRP-rules by sequentially applying composite primitives.
+#
+# To obtain the same set of rules as in the previous example, we can define
+composite = Composite(
+    ZeroRule(), # default rule
+    GlobalRuleMap(
+        Conv => GammaRule(), # apply GammaRule on all convolutional layers
+        MaxPool => EpsilonRule(), # apply EpsilonRule on all pooling layers
+    ),
+    FirstRule(ZBoxRule(0.0f0, 1.0f0)), # apply ZBoxRule on the first layer
+)
+
+analyzer = LRP(model, composite) # construct LRP analyzer from composite
+heatmap(input, analyzer)
+
+# This analyzer contains the same rules as our previous one:
+analyzer.rules # show rules
+
+# ### Composite primitives
+# The following sets of primitives can be used to construct a [`Composite`](@ref).
+#
+# To apply a single rule, use:
+# * [`LayerRule`](@ref) to apply a rule to the `n`-th layer of a model
+# * [`GlobalRule`](@ref) to apply a rule to all layers of a model
+# * [`RangeRule`](@ref) to apply a rule to a positional range of layers of a model
+# * [`FirstRule`](@ref) to apply a rule to the first layer of a model
+# * [`LastRule`](@ref) to apply a rule to the last layer of a model
+#
+# To apply a set of rules to multiple layers, use:
+# * [`GlobalRuleMap`](@ref) to apply a dictionary that maps layer types to LRP-rules
+# * [`RangeRuleMap`](@ref) for a `RuleMap` on generalized ranges
+# * [`FirstNRuleMap`](@ref) for a `RuleMap` on the first `n` layers of a model
+# * [`LastNRuleMap`](@ref) for a `RuleMap` on the last `n` layers
+#
+# Primitives are applied sequentially in the order they were passed to the `Composite`,
+# overwriting rules specified by previous primitives.
 
 # ## Custom LRP rules
 # Let's define a rule that modifies the weights and biases of our layer on the forward pass.
 # The rule has to be of type `AbstractLRPRule`.
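The overwrite semantics described above (a default rule, then primitives applied in order) can be sketched in a few lines of plain Julia. This is a hypothetical stand-in to illustrate the mechanism, not ExplainableAI's actual implementation; `GlobalPrim`, `FirstPrim`, and `build_rules` are made-up names.

```julia
# Toy sketch of composite semantics: primitives are applied in order to a
# vector of rule names, so later primitives overwrite earlier ones.
abstract type Primitive end

struct GlobalPrim <: Primitive  # hypothetical stand-in for GlobalRule
    rule::String
end
struct FirstPrim <: Primitive   # hypothetical stand-in for FirstRule
    rule::String
end

apply!(rules, p::GlobalPrim) = fill!(rules, p.rule)           # set every layer's rule
apply!(rules, p::FirstPrim)  = (rules[1] = p.rule; rules)     # set only the first layer's rule

# A "composite" is just an ordered list of primitives applied over a default:
function build_rules(nlayers, primitives...)
    rules = fill("ZeroRule", nlayers)  # default rule for all layers
    for p in primitives
        apply!(rules, p)               # later primitives overwrite earlier ones
    end
    return rules
end

println(build_rules(3, GlobalPrim("EpsilonRule"), FirstPrim("ZBoxRule")))
# ["ZBoxRule", "EpsilonRule", "EpsilonRule"]
```

Note that reversing the primitive order changes the result: a `GlobalPrim` applied last would overwrite the first-layer rule again, which is why order matters when constructing a `Composite`.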

docs/literate/example.jl

Lines changed: 17 additions & 11 deletions
@@ -20,6 +20,7 @@ model = BSON.load("../model.bson", @__MODULE__)[:model]
 #md # !!! warning "Strip softmax"
 #md #
 #md #     For models with softmax activations on the output, it is necessary to call
+#md #     [`strip_softmax`](@ref)
 #md #     ```julia
 #md #     model = strip_softmax(model)
 #md #     ```
@@ -59,14 +60,14 @@ expl = analyze(input, analyzer);
 # * `expl.neuron_selection`: the index of the output neuron used for the attribution
 # * `expl.analyzer`: a symbol corresponding to the analyzer used, e.g. `:LRP`
 
-# Finally, we can visualize the `Explanation` through heatmapping:
+# Finally, we can visualize the `Explanation` through [`heatmap`](@ref):
 heatmap(expl)
 
 # Or get the same result by combining both analysis and heatmapping into one step:
 heatmap(input, analyzer)
 
 # ## Neuron selection
-# By passing an additional index to our call to `analyze`, we can compute the attribution
+# By passing an additional index to our call to [`analyze`](@ref), we can compute the attribution
 # with respect to a specific output neuron.
 # Let's see why the output wasn't interpreted as a 4 (output neuron at index 5):
 heatmap(input, analyzer, 5)
@@ -76,7 +77,7 @@ heatmap(input, analyzer, 5)
 
 #md # !!! note
 #md #
-#md #     The ouput neuron can also be specified when calling `analyze`:
+#md #     The output neuron can also be specified when calling [`analyze`](@ref):
 #md #     ```julia
 #md #     expl = analyze(img, analyzer, 5)
 #md #     ```
@@ -110,26 +111,31 @@ mosaic(heatmap(batch, analyzer, 1); nrow=10)
 # ├── SmoothGrad
 # ├── IntegratedGradients
 # └── LRP
-#     ├── LRPZero
-#     ├── LRPEpsilon
-#     └── LRPGamma
+#     ├── ZeroRule
+#     ├── EpsilonRule
+#     ├── GammaRule
+#     ├── WSquareRule
+#     ├── FlatRule
+#     ├── ZBoxRule
+#     ├── AlphaBetaRule
+#     └── PassRule
 # ```
 #
-# Let's try `InputTimesGradient`
+# Let's try [`InputTimesGradient`](@ref)
 analyzer = InputTimesGradient(model)
 heatmap(input, analyzer)
 
-# and `Gradient`
+# and [`Gradient`](@ref)
 analyzer = Gradient(model)
 heatmap(input, analyzer)
 
-# As you can see, the function `heatmap` automatically applies common presets for each method.
+# As you can see, the function [`heatmap`](@ref) automatically applies common presets for each method.
 #
-# Since `InputTimesGradient` and LRP both compute attributions, their presets are similar.
+# Since [`InputTimesGradient`](@ref) and [`LRP`](@ref) both compute attributions, their presets are similar.
 # Gradient methods, however, are typically shown in grayscale.
 
 # ## Custom heatmap settings
-# We can partially or fully override presets by passing keyword arguments to `heatmap`:
+# We can partially or fully override presets by passing keyword arguments to [`heatmap`](@ref):
 using ColorSchemes
 heatmap(expl; cs=ColorSchemes.jet)
 #
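The difference between `Gradient` and `InputTimesGradient` attributions can be worked out by hand on a toy linear model. The matrix `W`, input `x`, and neuron index below are made-up example values; the real package differentiates full Flux models automatically rather than using this closed form.

```julia
# Toy linear "model" y = W * x: for output neuron k, the gradient of y[k]
# with respect to x is simply row k of W.
W = [1.0 -2.0; 0.5 3.0]
x = [2.0, 1.0]

k = 1                  # selected output neuron (cf. neuron selection above)
grad = W[k, :]         # Gradient-style attribution: ∂y[k]/∂x
attr = x .* grad       # InputTimesGradient-style attribution

println(grad)   # [1.0, -2.0]
println(attr)   # [2.0, -2.0]
```

Because the two methods only differ by the elementwise factor `x`, they often produce similar heatmaps on attribution-friendly presets, which matches the remark above that their presets are similar.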

docs/src/api.md

Lines changed: 22 additions & 0 deletions
@@ -45,6 +45,28 @@ LRP_CONFIG.supports_layer
 LRP_CONFIG.supports_activation
 ```
 
+## Composites
+```@docs
+Composite
+```
+
+Composite primitives that apply a single rule:
+```@docs
+LayerRule
+GlobalRule
+RangeRule
+FirstRule
+LastRule
+```
+
+Composite primitives that apply a set of rules to multiple layers:
+```@docs
+GlobalRuleMap
+RangeRuleMap
+FirstNRuleMap
+LastNRuleMap
+```
+
 # Utilities
 ```@docs
 strip_softmax
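The `RuleMap` primitives documented above match layers by type (e.g. `Conv => GammaRule()`). A minimal sketch of that type-based lookup, using hypothetical stand-in structs rather than Flux's real layer types or the package's real implementation:

```julia
# Stand-ins for Flux layer types (hypothetical; Flux is not loaded here):
struct Conv end
struct MaxPool end
struct Dense end

# An ordered map from layer types to rule names, first match wins:
rulemap = [Conv => "GammaRule", MaxPool => "EpsilonRule"]

function rule_for(layer, rulemap; default="ZeroRule")
    for (T, rule) in rulemap
        layer isa T && return rule   # pick the rule for the first matching type
    end
    return default                   # fall back to the default rule
end

layers = [Conv(), MaxPool(), Dense()]
println([rule_for(l, rulemap) for l in layers])
# ["GammaRule", "EpsilonRule", "ZeroRule"]
```

Using an ordered list of pairs rather than a `Dict` keeps the first-match-wins behavior explicit, which is the natural choice when several layer types could match the same layer.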

src/ExplainableAI.jl

Lines changed: 9 additions & 1 deletion
@@ -20,15 +20,17 @@ using PrettyTables
 include("compat.jl")
 include("neuron_selection.jl")
 include("analyze_api.jl")
-include("types.jl")
+include("flux_types.jl")
 include("flux_utils.jl")
 include("utils.jl")
 include("input_augmentation.jl")
 include("gradient.jl")
 include("lrp/canonize.jl")
 include("lrp/checks.jl")
 include("lrp/rules.jl")
+include("lrp/composite.jl")
 include("lrp/lrp.jl")
+include("lrp/show.jl")
 include("heatmap.jl")
 include("preprocessing.jl")
 export analyze
@@ -49,6 +51,12 @@ export modify_input, modify_denominator
 export modify_param!, modify_layer!
 export check_model
 
+# LRP composites
+export Composite, AbstractCompositePrimitive
+export LayerRule, GlobalRule, RangeRule, FirstRule, LastRule
+export GlobalRuleMap, RangeRuleMap, FirstNRuleMap, LastNRuleMap
+export ConvLayer, PoolingLayer, DropoutLayer, ReshapingLayer
+
 # heatmapping
 export heatmap
File renamed without changes.
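The `include`/`export` layout above is the standard way Julia packages split one module across files. A self-contained toy version of the pattern (hypothetical `MiniXAI` module with an inline definition instead of `include`d files):

```julia
# Toy module mirroring the package layout: definitions normally pulled in
# via include("...") are written inline here, and selected names exported.
module MiniXAI

# toy stand-in for a model check: reject chains ending in a softmax layer
check_model(layers) = all(l -> l != "softmax", layers)

export check_model

end # module

using .MiniXAI
println(check_model(["Conv", "Dense"]))    # true
println(check_model(["Dense", "softmax"])) # false
```

Only exported names become available via `using`; unexported helpers stay reachable as `MiniXAI.name`, which is why the commit adds explicit `export` lines for the new composite types.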
