@@ -48,38 +48,39 @@ heatmap(input, analyzer)
# The rule has to be of type `AbstractLRPRule`.
struct MyGammaRule <: AbstractLRPRule end
- # It is then possible to dispatch on the utility functions [`modify_params`](@ref) and [`modify_denominator`](@ref)
- # with the rule type `MyCustomLRPRule` to define custom rules without writing any boilerplate code.
+ # It is then possible to dispatch on the utility functions [`modify_input`](@ref),
+ # [`modify_param!`](@ref) and [`modify_denominator`](@ref) with the rule type
+ # `MyGammaRule` to define custom rules without writing any boilerplate code.
# To extend internal functions, import them explicitly:
- import ExplainableAI: modify_params
+ import ExplainableAI: modify_param!
- function modify_params(::MyGammaRule, W, b)
-     ρW = W + 0.25 * relu.(W)
-     ρb = b + 0.25 * relu.(b)
-     return ρW, ρb
+ function modify_param!(::MyGammaRule, param)
+     param .+= 0.25 * relu.(param)
+     return nothing
end
# We can directly use this rule to make an analyzer!
analyzer = LRP(model, MyGammaRule())
heatmap(input, analyzer)
- # We just implemented our own version of the ``γ``-rule in 7 lines of code!
+ # We just implemented our own version of the ``γ``-rule in 4 lines of code!
# The outputs match perfectly:
analyzer = LRP(model, GammaRule())
heatmap(input, analyzer)
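To make the comparison concrete, the attributions can also be checked numerically. The following is only a sketch and not part of this diff; it assumes that `analyze` returns an explanation object whose attribution values are stored in a `val` field.

```julia
# Hypothetical check that both rules produce the same attributions
# (assumes the explanation stores its values in a `val` field)
expl_custom = analyze(input, LRP(model, MyGammaRule()))
expl_gamma  = analyze(input, LRP(model, GammaRule()))
expl_custom.val ≈ expl_gamma.val  # expected to be true
```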
- # If the layer doesn't use weights and biases `W` and `b`, ExplainableAI provides a
- # lower-level variant of [`modify_params`](@ref) called [`modify_layer`](@ref).
- # This function is expected to take a layer and return a new, modified layer.
+ # If the layer doesn't use weights `layer.weight` and biases `layer.bias`,
+ # ExplainableAI provides a lower-level variant of [`modify_param!`](@ref)
+ # called [`modify_layer!`](@ref). This function is expected to modify the
+ # layer in place.
- # md # !!! warning "Using modify_layer"
+ # md # !!! warning "Using modify_layer!"
# md #
- # md # Use of the function `modify_layer` will overwrite functionality of `modify_params`
+ # md # Use of the function `modify_layer!` will overwrite functionality of `modify_param!`
# md # for the implemented combination of rule and layer types.
- # md # This is due to the fact that internally, `modify_params` is called by the default
- # md # implementation of `modify_layer`.
+ # md # This is due to the fact that internally, `modify_param!` is called by the default
+ # md # implementation of `modify_layer!`.
# md #
- # md # Therefore it is recommended to only extend `modify_layer` for a specific rule
+ # md # Therefore it is recommended to only extend `modify_layer!` for a specific rule
# md # and a specific layer type.
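To illustrate the recommendation above, a hypothetical extension of `modify_layer!` for one specific rule and layer type could look like the following sketch. It assumes the in-place API shown in this diff; the restriction to `Dense` and the `bias isa AbstractVector` guard are illustrative assumptions, not taken from the source.

```julia
import ExplainableAI: modify_layer!

# Sketch: overwrite the layer-level hook for MyGammaRule on Dense layers only
function modify_layer!(::MyGammaRule, layer::Dense)
    layer.weight .+= 0.25 * relu.(layer.weight)
    if layer.bias isa AbstractVector  # Flux layers may carry `bias = false`
        layer.bias .+= 0.25 * relu.(layer.bias)
    end
    return nothing
end
```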
# ## Custom layers and activation functions
@@ -202,7 +203,7 @@ analyzer = LRPZero(model)
# The correct rule is applied via [multiple dispatch](https://www.youtube.com/watch?v=kc9HwsxE1OY)
# on the types of the arguments `rule` and `layer`.
# The relevance `Rₖ` is then computed based on the input activation `aₖ` and the output relevance `Rₖ₊₁`.
- # Multiple dispatch is also used to dispatch `modify_params` and `modify_denominator` on the rule and layer type.
+ # Multiple dispatch is also used to dispatch `modify_param!` and `modify_denominator` on the rule and layer type.
#
# Calling `analyze` on a LRP-model applies a forward-pass of the model, keeping track of
# the activations `aₖ` for each layer `k`.
@@ -215,7 +216,7 @@ analyzer = LRPZero(model)
# R_{j}=\sum_{k} \frac{a_{j} \cdot \rho\left(w_{j k}\right)}{\epsilon+\sum_{0, j} a_{j} \cdot \rho\left(w_{j k}\right)} R_{k}
# ```
#
- # where ``\rho`` is a function that modifies parameters – what we have so far called `modify_params`.
+ # where ``\rho`` is a function that modifies parameters – what we call `modify_param!`.
#
# The computation of this propagation rule can be decomposed into four steps:
# ```math
@@ -243,21 +244,20 @@ analyzer = LRPZero(model)
# ### AD fallback
# The default LRP fallback for unknown layers uses AD via [Zygote](https://github.com/FluxML/Zygote.jl).
- # For `lrp!`, we end up with something that looks very similar to the previous four step computation:
+ # For `lrp!`, we implement the previous four-step computation using `Zygote.pullback` to
+ # compute ``c`` from the previous equation as a VJP, pulling back ``s_{k}=R_{k}/z_{k}``:
# ```julia
# function lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)
- #     layerᵨ = modify_layer(rule, layer)
- #     c = gradient(aₖ) do a
- #         z = layerᵨ(a)
- #         s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
- #         z ⋅ s
- #     end |> only
- #     Rₖ .= aₖ .* c
+ #     reset! = get_layer_resetter(rule, layer)
+ #     modify_layer!(rule, layer)
+ #     ãₖ₊₁, pullback = Zygote.pullback(layer, modify_input(rule, aₖ))
+ #     Rₖ .= aₖ .* only(pullback(Rₖ₊₁ ./ modify_denominator(rule, ãₖ₊₁)))
+ #     reset!()
# end
# ```
#
- # You can see how `modify_layer` and `modify_denominator` dispatch on the rule and layer type.
- # This is how we implemented our own `MyGammaRule`.
+ # You can see how `modify_layer!`, `modify_input` and `modify_denominator` dispatch on the
+ # rule and layer type. This is how we implemented our own `MyGammaRule`.
# Unknown layers that are registered in the `LRP_CONFIG` use this exact function.
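For reference, registering such an unknown layer might look like the sketch below. The `LRP_CONFIG.supports_layer` call is assumed to be the registration mechanism referred to in the "Custom layers and activation functions" section, and `MyDoublingLayer` is a made-up example layer.

```julia
# Hypothetical custom layer that simply doubles its input
struct MyDoublingLayer end
(m::MyDoublingLayer)(x) = 2 .* x

# Assumed registration API: opt the layer in, so the AD fallback `lrp!` above is applied
LRP_CONFIG.supports_layer(::MyDoublingLayer) = true
```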
# ### Specialized implementations
@@ -267,7 +267,7 @@ analyzer = LRPZero(model)
# Reshaping layers don't affect attributions. We can therefore avoid the computational
# overhead of AD by writing a specialized implementation that simply reshapes back:
# ```julia
- # function lrp!(Rₖ, ::AbstractLRPRule, ::ReshapingLayer, aₖ, Rₖ₊₁)
+ # function lrp!(Rₖ, rule, ::ReshapingLayer, aₖ, Rₖ₊₁)
#     Rₖ .= reshape(Rₖ₊₁, size(aₖ))
# end
# ```
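In the same spirit, a layer that acts as the identity at inference time could pass relevance through unchanged. This is a hypothetical sketch, not taken from the source; `Dropout` is used purely as an example of such a layer.

```julia
import ExplainableAI: lrp!

# Sketch: Dropout is a no-op at inference time, so relevance passes straight through
function lrp!(Rₖ, rule, ::Dropout, aₖ, Rₖ₊₁)
    Rₖ .= Rₖ₊₁
end
```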
@@ -276,14 +276,16 @@ analyzer = LRPZero(model)
#
# We can even implement the generic rule as a specialized implementation for `Dense` layers:
# ```julia
- # function lrp!(Rₖ, rule::AbstractLRPRule, layer::Dense, aₖ, Rₖ₊₁)
- #     ρW, ρb = modify_params(rule, get_params(layer)...)
- #     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
- #     @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k] # Tullio ≈ fast einsum
+ # function lrp!(Rₖ, rule, layer::Dense, aₖ, Rₖ₊₁)
+ #     reset! = get_layer_resetter(rule, layer)
+ #     modify_layer!(rule, layer)
+ #     ãₖ₊₁ = modify_denominator(rule, layer(modify_input(rule, aₖ)))
+ #     @tullio Rₖ[j, b] = aₖ[j, b] * layer.weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b] # Tullio ≈ fast einsum
+ #     reset!()
# end
# ```
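For readers unfamiliar with Tullio: indices that only appear on the right-hand side are summed over, so the einsum above corresponds roughly to the following loops (shown purely for illustration):

```julia
# Unoptimized equivalent of the @tullio expression: sum over the output index k
# for every input index j and batch index b
for b in axes(aₖ, 2), j in axes(aₖ, 1)
    Rₖ[j, b] = aₖ[j, b] * sum(layer.weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b] for k in axes(Rₖ₊₁, 1))
end
```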
#
- # For maximum low-level control beyond `modify_layer`, `modify_params` and `modify_denominator`,
+ # For maximum low-level control beyond `modify_layer!`, `modify_param!` and `modify_denominator`,
# you can also implement your own `lrp!` function and dispatch
# on individual rule types `MyRule` and layer types `MyLayer`:
# ```julia