
Commit a6161bc

Update docs with advanced LRP example (#40)
* Simple MNIST example
* Advanced LRP documentation
* Make `modify_layer` public again now that it's documented
1 parent 1725863 commit a6161bc

11 files changed, +320 −96 lines

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
@@ -63,5 +63,6 @@ jobs:
             doctest(ExplainabilityMethods)'
       - run: julia --project=docs docs/make.jl
         env:
+          DATADEPS_ALWAYS_ACCEPT: true # for MLDatasets download
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
           DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 using BenchmarkTools
 using Flux
 using ExplainabilityMethods
-import ExplainabilityMethods: _modify_layer, lrp!
+import ExplainabilityMethods: modify_layer, lrp!

 on_CI = haskey(ENV, "GITHUB_ACTIONS")

@@ -43,7 +43,7 @@ struct TestWrapper{T}
     layer::T
 end
 (w::TestWrapper)(x) = w.layer(x)
-_modify_layer(r::AbstractLRPRule, w::TestWrapper) = _modify_layer(r, w.layer)
+modify_layer(r::AbstractLRPRule, w::TestWrapper) = modify_layer(r, w.layer)
 lrp!(rule::ZBoxRule, w::TestWrapper, Rₖ, aₖ, Rₖ₊₁) = lrp!(rule, w.layer, Rₖ, aₖ, Rₖ₊₁)

 # generate input for conv layers

docs/Project.toml

Lines changed: 4 additions & 5 deletions
@@ -1,10 +1,9 @@
 [deps]
-DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
+BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 ExplainabilityMethods = "cd722a4f-8d55-446b-8550-a4aabc9151ab"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
-Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
+ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
-Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
-TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"

docs/literate/advanced_lrp.jl

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
# # Advanced LRP usage
# One of the design goals of ExplainabilityMethods.jl is to combine ease of use with
# **extensibility** for the purpose of research.
#
# This example will show you how to implement custom LRP rules and register custom layers
# and activation functions.
#
# For this purpose, we will quickly load our model from the previous section:
using ExplainabilityMethods
using Flux
using MLDatasets
using ImageCore
using BSON

model = BSON.load("../model.bson", @__MODULE__)[:model]

index = 10
x, y = MNIST.testdata(Float32, index)
input = reshape(x, 28, 28, 1, :);
# ## Custom LRP rules
# Let's define a rule that modifies the weights and biases of our layer on the forward pass.
# The rule has to be a subtype of `AbstractLRPRule`.
struct MyGammaRule <: AbstractLRPRule end

# It is then possible to dispatch on the utility functions [`modify_params`](@ref) and [`modify_denominator`](@ref)
# with our rule type `MyGammaRule` to define custom rules without writing any boilerplate code.
# To extend these internal functions, import them explicitly:
import ExplainabilityMethods: modify_params

function modify_params(::MyGammaRule, W, b)
    ρW = W + 0.25 * relu.(W)
    ρb = b + 0.25 * relu.(b)
    return ρW, ρb
end

# We can directly use this rule to make an analyzer!
analyzer = LRP(model, MyGammaRule())
heatmap(input, analyzer)

# We just implemented our own version of the ``γ``-rule in 7 lines of code!
# The output matches that of the built-in `GammaRule`:
analyzer = LRP(model, GammaRule())
heatmap(input, analyzer)
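# The same dispatch mechanism applies to [`modify_denominator`](@ref), which modifies the
# denominator of the relevance propagation step. As a hypothetical sketch (not executed here),
# a rule could add a small stabilizing constant, assuming the two-argument form
# `modify_denominator(rule, d)` used in the fallback implementation shown in the
# *How it works internally* section below:
# ```julia
# import ExplainabilityMethods: modify_denominator
# modify_denominator(::MyGammaRule, d) = d .+ 1.0f-9
# ```
#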
# If the layer doesn't use weights and biases `W` and `b`, ExplainabilityMethods provides a
# lower-level variant of [`modify_params`](@ref) called [`modify_layer`](@ref).
# This function is expected to take a layer and return a new, modified layer.

#md # !!! warning "Using `modify_layer`"
#md #
#md #     Extending the function `modify_layer` overrides the functionality of `modify_params`
#md #     for the implemented combination of rule and layer types.
#md #     This is because internally, `modify_params` is called by the default
#md #     implementation of `modify_layer`.
#md #
#md #     Therefore it is recommended to only extend `modify_layer` for a specific rule
#md #     and a specific layer type.
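#
# For example, the benchmark suite of this package wraps layers in a struct that simply forwards
# calls to an inner layer. A minimal sketch of this pattern, using a hypothetical wrapper type
# `MyWrapper`, delegates `modify_layer` to the wrapped layer:
# ```julia
# import ExplainabilityMethods: modify_layer
#
# struct MyWrapper{T}
#     layer::T
# end
# (w::MyWrapper)(x) = w.layer(x)
#
# modify_layer(rule::AbstractLRPRule, w::MyWrapper) = modify_layer(rule, w.layer)
# ```
# Returning the modified inner layer is sufficient here, since the wrapper only forwards calls to it.
#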
# ## Custom layers and activation functions
# ### Model checks for humans
# Good model checks and presets should allow novice users to apply XAI methods
# in a "plug & play" manner according to best practices.
#
# Let's say we define a layer that doubles its input:
struct MyDoublingLayer end
(::MyDoublingLayer)(x) = 2 * x

mylayer = MyDoublingLayer()
mylayer([1, 2, 3])

# Let's append this layer to our model:
model = Chain(model..., MyDoublingLayer())

# Creating an LRP analyzer, e.g. `LRPZero(model)`, will throw an `ArgumentError`
# and print a summary of the model check in the REPL:
# ```julia-repl
# ┌───┬───────────────────────┬─────────────────┬────────────┬────────────────┐
# │   │ Layer                 │ Layer supported │ Activation │ Act. supported │
# ├───┼───────────────────────┼─────────────────┼────────────┼────────────────┤
# │ 1 │ flatten               │ true            │ —          │ true           │
# │ 2 │ Dense(784, 100, relu) │ true            │ relu       │ true           │
# │ 3 │ Dense(100, 10)        │ true            │ identity   │ true           │
# │ 4 │ MyDoublingLayer()     │ false           │ —          │ true           │
# └───┴───────────────────────┴─────────────────┴────────────┴────────────────┘
# Layers failed model check
# ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡
#
# Found unknown layers MyDoublingLayer() that are not supported by ExplainabilityMethods' LRP implementation yet.
#
# If you think the missing layer should be supported by default, please submit an issue (https://github.com/adrhill/ExplainabilityMethods.jl/issues).
#
# These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument skip_checks=true.
#
# [...]
# ```

# LRP should only be used on "Deep ReLU" networks and ExplainabilityMethods doesn't
# recognize `MyDoublingLayer` as a compatible layer.
# By default, it will therefore return an error and a model check summary
# instead of returning an incorrect explanation.
#
# However, if we know `MyDoublingLayer` is compatible with "Deep ReLU" networks,
# we can register it to tell ExplainabilityMethods that it is ok to use.
# This will be shown in the following section.

#md # !!! warning "Skipping model checks"
#md #
#md #     All model checks can be skipped at the user's own risk by setting the LRP-analyzer
#md #     keyword argument `skip_checks=true`.
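#
# If you prefer to bypass the checks instead of registering layers, a hypothetical call
# (assuming the keyword is passed directly to the analyzer constructor, as the error message
# above suggests) would look like:
# ```julia
# analyzer = LRPZero(model; skip_checks=true)
# ```
#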
# ### Registering custom layers
# The model check will no longer error once we register our custom layer type
# `MyDoublingLayer` as "supported" by ExplainabilityMethods.
#
# This is done using the function [`LRP_CONFIG.supports_layer`](@ref), which should be set to return `true`:
LRP_CONFIG.supports_layer(::MyDoublingLayer) = true

# Now we can create and run an analyzer without getting an error:
analyzer = LRPZero(model)
heatmap(input, analyzer)

#md # !!! note "Registering functions"
#md #
#md #     Flux's `Chain`s can also contain functions, e.g. `flatten`.
#md #     This kind of layer can be registered as
#md #     ```julia
#md #     LRP_CONFIG.supports_layer(::typeof(mylayer)) = true
#md #     ```

# ### Registering activation functions
# The mechanism for registering custom activation functions is analogous to that for custom layers:
myrelu(x) = max.(0, x)
model = Chain(flatten, Dense(784, 100, myrelu), Dense(100, 10))

# Once again, creating an LRP analyzer for this model will throw an `ArgumentError`
# and display the following model check summary:
# ```julia-repl
# julia> analyzer = LRPZero(model)
# ┌───┬─────────────────────────┬─────────────────┬────────────┬────────────────┐
# │   │ Layer                   │ Layer supported │ Activation │ Act. supported │
# ├───┼─────────────────────────┼─────────────────┼────────────┼────────────────┤
# │ 1 │ flatten                 │ true            │ —          │ true           │
# │ 2 │ Dense(784, 100, myrelu) │ true            │ myrelu     │ false          │
# │ 3 │ Dense(100, 10)          │ true            │ identity   │ true           │
# └───┴─────────────────────────┴─────────────────┴────────────┴────────────────┘
# Activations failed model check
# ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡
#
# Found layers with unknown or unsupported activation functions myrelu. LRP assumes that the model is a "deep rectifier network" that only contains ReLU-like activation functions.
#
# If you think the missing activation function should be supported by default, please submit an issue (https://github.com/adrhill/ExplainabilityMethods.jl/issues).
#
# These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument skip_checks=true.
#
# [...]
# ```

# Registration works by defining the function [`LRP_CONFIG.supports_activation`](@ref) to return `true`:
LRP_CONFIG.supports_activation(::typeof(myrelu)) = true

# Now the analyzer can be created without error:
analyzer = LRPZero(model)
# ## How it works internally
# Internally, ExplainabilityMethods dispatches to low-level functions of the form
# ```julia
# lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
# ```
# These functions dispatch on rule and layer type and modify the pre-allocated arrays `Rₖ`
# in place, based on the inputs `aₖ` and `Rₖ₊₁`.
#
# The default LRP fallback for unknown layers uses automatic differentiation (AD) via Zygote:
# ```julia
# function lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
#     layerᵨ = modify_layer(rule, layer)
#     c = gradient(aₖ) do a
#         z = layerᵨ(a)
#         s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
#         z ⋅ s
#     end |> only
#     Rₖ .= aₖ .* c
# end
# ```
#
# Here you can clearly see how this AD fallback dispatches on `modify_layer` and `modify_denominator`
# based on the rule and layer type. This is how we implemented our own `MyGammaRule`!
# Unknown layers that are registered in the `LRP_CONFIG` use this exact function.
#
# We can also implement versions of `lrp!` that are specialized for specific layer types.
# For example, reshaping layers don't affect attributions, therefore no AD is required.
# ExplainabilityMethods implements:
# ```julia
# function lrp!(rule, ::ReshapingLayer, Rₖ, aₖ, Rₖ₊₁)
#     Rₖ .= reshape(Rₖ₊₁, size(aₖ))
# end
# ```
#
# Even `Dense` layers have a specialized implementation:
# ```julia
# function lrp!(rule, layer::Dense, Rₖ, aₖ, Rₖ₊₁)
#     ρW, ρb = modify_params(rule, get_params(layer)...)
#     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
#     @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k] # Tullio = fast einsum
# end
# ```
# Just like in the LRP papers!
#
# For maximum low-level control, you can also implement your own `lrp!` function
# and dispatch on individual rule types `MyRule` and layer types `MyLayer`:
# ```julia
# function lrp!(rule::MyRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁)
#     Rₖ .= ...
# end
# ```
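#
# As a sketch of what such a specialization could look like: for an element-wise scaling layer
# like the `MyDoublingLayer` defined above, relevance is conserved per element, so a hypothetical
# rule type `MyRule` could simply pass it through:
# ```julia
# function lrp!(rule::MyRule, ::MyDoublingLayer, Rₖ, aₖ, Rₖ₊₁)
#     Rₖ .= Rₖ₊₁
# end
# ```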
