Commit 2fbc117

Update docs with LRP equations (#41)
1 parent a6161bc commit 2fbc117

2 files changed: +75 -27 lines

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@ model = strip_softmax(vgg.layers)
 
 # Run XAI method
 analyzer = LRPEpsilon(model)
-expl, out = analyze(img, analyzer)
+expl = analyze(img, analyzer)
 
 # Show heatmap
 heatmap(expl)
```

docs/literate/advanced_lrp.jl

Lines changed: 74 additions & 26 deletions
```diff
@@ -1,11 +1,10 @@
 # # Advanced LRP usage
-# One of the design goals of ExplainabilityMethods.jl is to combine ease of use with
-# **extensibility** for the purpose of research.
+# One of the design goals of ExplainabilityMethods.jl is to combine ease of use and
+# extensibility for the purpose of research.
 #
 #
 # This example will show you how to implement custom LRP rules and register custom layers
 # and activation functions.
-#
 # For this purpose, we will quickly load our model from the previous section:
 using ExplainabilityMethods
 using Flux
```
```diff
@@ -25,8 +24,8 @@ input = reshape(x, 28, 28, 1, :);
 struct MyGammaRule <: AbstractLRPRule end
 
 # It is then possible to dispatch on the utility functions [`modify_params`](@ref) and [`modify_denominator`](@ref)
-# with our rule type `MyCustomLRPRule` to define custom rules without writing any boilerplate code.
-# to extend internal functions, import them explicitly:
+# with the rule type `MyGammaRule` to define custom rules without writing any boilerplate code.
+# To extend internal functions, import them explicitly:
 import ExplainabilityMethods: modify_params
 
 function modify_params(::MyGammaRule, W, b)
```
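The hunk above ends at the signature of `modify_params` for `MyGammaRule`; its body lies outside the shown context. As a reading aid, a γ-rule conventionally boosts positive weights and biases by a factor γ. The sketch below is an assumption based on that convention; the value `0.25` and the use of `relu` are not taken from this diff:

```julia
using Flux: relu

# Hypothetical body for the rule defined above: strengthen positive
# weights and biases by an assumed factor γ = 0.25, as the standard
# γ-rule does. The denominator is left at its default.
function modify_params(::MyGammaRule, W, b)
    γ = 0.25
    ρW = W + γ * relu.(W)
    ρb = b + γ * relu.(b)
    return ρW, ρb
end
```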
```diff
@@ -48,7 +47,7 @@ heatmap(input, analyzer)
 # lower-level variant of [`modify_params`](@ref) called [`modify_layer`](@ref).
 # This function is expected to take a layer and return a new, modified layer.
 
-#md # !!! warning "Using `modify_layer`"
+#md # !!! warning "Using modify_layer"
 #md #
 #md # Use of the function `modify_layer` will overwrite functionality of `modify_params`
 #md # for the implemented combination of rule and layer types.
```
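To illustrate the statement that `modify_layer` "is expected to take a layer and return a new, modified layer", here is a hedged sketch of such a method for `Dense` layers. The call to `get_params` mirrors the `Dense` example further down in this diff; rebuilding the layer via `Dense(ρW, ρb, layer.σ)` is an assumption about the Flux API in use:

```julia
import ExplainabilityMethods: modify_layer

# Hypothetical sketch: construct a new Dense layer whose parameters were
# passed through modify_params, keeping the original activation function σ.
function modify_layer(rule::MyGammaRule, layer::Dense)
    ρW, ρb = modify_params(rule, get_params(layer)...)
    return Dense(ρW, ρb, layer.σ)
end
```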
```diff
@@ -96,12 +95,12 @@ model = Chain(model..., MyDoublingLayer())
 # [...]
 # ```
 
-# LRP should only be used on "Deep ReLU" networks and ExplainabilityMethods doesn't
+# LRP should only be used on deep rectifier networks and ExplainabilityMethods doesn't
 # recognize `MyDoublingLayer` as a compatible layer.
 # By default, it will therefore return an error and a model check summary
 # instead of returning an incorrect explanation.
 #
-# However, if we know `MyDoublingLayer` is compatible with "Deep ReLU" networks,
+# However, if we know `MyDoublingLayer` is compatible with deep rectifier networks,
 # we can register it to tell ExplainabilityMethods that it is ok to use.
 # This will be shown in the following section.
 
```
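For context, `MyDoublingLayer` is defined earlier in the file, outside the shown hunks. A minimal definition consistent with its name and its use in `Chain(model..., MyDoublingLayer())` would be a parameter-free callable struct; the exact definition below is an assumption, not taken from this diff:

```julia
# Hypothetical definition of the custom layer referenced above:
# a parameter-free layer that doubles its input element-wise.
struct MyDoublingLayer end
(::MyDoublingLayer)(x) = 2 .* x
```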
```diff
@@ -114,7 +113,8 @@ model = Chain(model..., MyDoublingLayer())
 # The error in the model check will stop after registering our custom layer type
 # `MyDoublingLayer` as "supported" by ExplainabilityMethods.
 #
-# This is done using the function [`LRP_CONFIG.supports_layer`](@ref), which should be set to return `true`:
+# This is done using the function [`LRP_CONFIG.supports_layer`](@ref),
+# which should be set to return `true` for the type `MyDoublingLayer`:
 LRP_CONFIG.supports_layer(::MyDoublingLayer) = true
 
 # Now we can create and run an analyzer without getting an error:
```
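The hunk stops at the comment announcing that an analyzer can now be created and run. Based on the calls that appear elsewhere in this diff (`LRPZero(model)` and `heatmap(input, analyzer)`), that step would look roughly like the following sketch:

```julia
# With MyDoublingLayer registered as supported, the model check passes
# and the analyzer can be constructed and applied without an error.
analyzer = LRPZero(model)
heatmap(input, analyzer)
```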
```diff
@@ -166,49 +166,97 @@ analyzer = LRPZero(model)
 # ## How it works internally
 # Internally, ExplainabilityMethods dispatches to low level functions
 # ```julia
-# lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
+#     Rₖ .= ...
+# end
 # ```
-# These functions dispatch on rule and layer type and inplace-modify pre-allocated arrays `Rₖ`
-# based on the inputs `aₖ` and `Rₖ₊₁`.
+# These functions use the arguments `rule` and `layer` to dispatch
+# `modify_params` and `modify_denominator` on the rule and layer type.
+# They in-place modify a pre-allocated array of the input relevance `Rₖ`
+# based on the input activation `aₖ` and output relevance `Rₖ₊₁`.
 #
-# The default LRP fallback for unknown layers uses automatic differentiation (AD) via Zygote:
+# Calling `analyze` then applies a forward pass of the model, keeping track of
+# the activations `aₖ` for each layer `k`.
+# The relevance `Rₖ₊₁` is then set to the output neuron activation and the rules are applied
+# in a backward pass over the model layers and previous activations.
+
+# ### Generic rule implementation using automatic differentiation
+# The generic LRP rule, of which the ``0``-, ``\epsilon``- and ``\gamma``-rules are special cases, reads[^1][^2]:
+# ```math
+# R_{j}=\sum_{k} \frac{a_{j} \cdot \rho\left(w_{j k}\right)}{\epsilon+\sum_{0, j} a_{j} \cdot \rho\left(w_{j k}\right)} R_{k}
+# ```
+#
+# where ``\rho`` is a function that modifies parameters, which we have so far called `modify_params`.
+#
+# The computation of this propagation rule can be decomposed into four steps:
+# ```math
+# \begin{array}{lr}
+# \forall_{k}: z_{k}=\epsilon+\sum_{0, j} a_{j} \cdot \rho\left(w_{j k}\right) & \text { (forward pass) } \\
+# \forall_{k}: s_{k}=R_{k} / z_{k} & \text { (element-wise division) } \\
+# \forall_{j}: c_{j}=\sum_{k} \rho\left(w_{j k}\right) \cdot s_{k} & \text { (backward pass) } \\
+# \forall_{j}: R_{j}=a_{j} c_{j} & \text { (element-wise product) }
+# \end{array}
+# ```
+#
+# For deep rectifier networks, the third step can also be written as the gradient computation
+# ```math
+# c_{j}=\left[\nabla\left(\sum_{k} z_{k}(\boldsymbol{a}) \cdot s_{k}\right)\right]_{j}
+# ```
+#
+# and can be implemented via automatic differentiation (AD).
+#
+# This equation is implemented in ExplainabilityMethods as the default method
+# for all layer types that don't have a specialized implementation.
+# We will refer to it as the "AD fallback".
+#
+# [^1]: G. Montavon et al., [Layer-Wise Relevance Propagation: An Overview](https://link.springer.com/chapter/10.1007/978-3-030-28954-6_10)
+# [^2]: W. Samek et al., [Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications](https://ieeexplore.ieee.org/document/9369420)
+
+# ### AD fallback
+# The default LRP fallback for unknown layers uses AD via [Zygote](https://github.com/FluxML/Zygote.jl).
+# For `lrp!`, we end up with something that looks very similar to the previous four-step computation:
 # ```julia
 # function lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
 #     layerᵨ = modify_layer(rule, layer)
 #     c = gradient(aₖ) do a
 #         z = layerᵨ(a)
 #         s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
 #         z ⋅ s
-# end |> only
+#     end |> only
 #     Rₖ .= aₖ .* c
 # end
 # ```
 #
-# Here you can clearly see how this AD-fallback dispatches on `modify_layer` and `modify_denominator`
-# based on the rule and layer type. This is how we implemented our own `MyGammaRule`!
+# You can see how `modify_layer` and `modify_denominator` dispatch on the rule and layer type.
+# This is how we implemented our own `MyGammaRule`.
 # Unknown layers that are registered in the `LRP_CONFIG` use this exact function.
+
+# ### Specialized implementations
+# We can also implement specialized versions of `lrp!` based on the type of `layer`,
+# e.g. reshaping layers.
 #
-# We can also implement versions of `lrp!` that are specialized for specific layer type.
-# For example, reshaping layers don't affect attributions, therefore no AD is required.
-# ExplainabilityMethods implements:
+# Reshaping layers don't affect attributions. We can therefore avoid the computational
+# overhead of AD by writing a specialized implementation that simply reshapes back:
 # ```julia
-# function lrp!(rule, ::ReshapingLayer, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(::AbstractLRPRule, ::ReshapingLayer, Rₖ, aₖ, Rₖ₊₁)
 #     Rₖ .= reshape(Rₖ₊₁, size(aₖ))
 # end
 # ```
 #
-# Even Dense layers have a specialized implementation:
+# Since the rule type didn't matter in this case, we didn't specify it.
+#
+# We can even implement the generic rule as a specialized implementation for `Dense` layers:
 # ```julia
-# function lrp!(rule, layer::Dense, Rₖ, aₖ, Rₖ₊₁)
+# function lrp!(rule::AbstractLRPRule, layer::Dense, Rₖ, aₖ, Rₖ₊₁)
 #     ρW, ρb = modify_params(rule, get_params(layer)...)
 #     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
-#     @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]  # Tullio = fast einsum
+#     @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]  # Tullio fast einsum
 # end
 # ```
-# Just like in the LRP papers!
 #
-# For maximum low-level control, you can also implement your own `lrp!` function
-# and dispatch on individual rule types `MyRule` and layer types `MyLayer`:
+# For maximum low-level control beyond `modify_layer`, `modify_params` and `modify_denominator`,
+# you can also implement your own `lrp!` function and dispatch
+# on individual rule types `MyRule` and layer types `MyLayer`:
 # ```julia
 # function lrp!(rule::MyRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁)
 #     Rₖ .= ...
```
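As a reading aid for the forward-pass/backward-pass description near the top of this hunk, the sketch below outlines how such an LRP pass could be organized. The function name `explain_lrp` and the single shared `rule` argument are illustrative assumptions, not the package's actual implementation:

```julia
# Hypothetical sketch of the LRP pass described above.
function explain_lrp(rule, layers, input)
    # Forward pass: keep track of the activation aₖ of every layer k.
    acts = Any[input]
    for layer in layers
        push!(acts, layer(acts[end]))
    end

    # The relevance of the last layer is set to the output activation.
    rels = Any[similar(a) for a in acts]
    rels[end] .= acts[end]

    # Backward pass: apply lrp! over layers and previous activations,
    # filling each Rₖ in-place from aₖ and Rₖ₊₁.
    for k in length(layers):-1:1
        lrp!(rule, layers[k], rels[k], acts[k], rels[k + 1])
    end
    return first(rels)  # relevance attributed to the input
end
```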
