
Commit e359b98

In-place modify layers and rewrite ZBoxRule to use VJP (#73)
* In-place modify layers, rewrite `ZBoxRule` to use VJP and introduce `modify_input`.
* Update tests and benchmarks around `TestWrapper`
* Update references for new `ZBoxRule`
* Update docs to new backend design
1 parent: 02a4bb2

12 files changed: +204 −139 lines


benchmark/benchmarks.jl

Lines changed: 5 additions & 2 deletions
@@ -2,7 +2,7 @@ using BenchmarkTools
 using LoopVectorization
 using Flux
 using ExplainableAI
-import ExplainableAI: modify_layer, lrp!
+import ExplainableAI: lrp!, modify_layer!, get_layer_resetter

 on_CI = haskey(ENV, "GITHUB_ACTIONS")

@@ -51,7 +51,10 @@ struct TestWrapper{T}
     layer::T
 end
 (w::TestWrapper)(x) = w.layer(x)
-modify_layer(r::AbstractLRPRule, w::TestWrapper) = modify_layer(r, w.layer)
+modify_layer!(rule::R, w::TestWrapper) where {R} = modify_layer!(rule, w.layer)
+get_layer_resetter(rule::R, w::TestWrapper) where {R} = get_layer_resetter(rule, w.layer)
+get_layer_resetter(::ZeroRule, w::TestWrapper) = Returns(nothing)
+get_layer_resetter(::EpsilonRule, w::TestWrapper) = Returns(nothing)
 lrp!(Rₖ, rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp!(Rₖ, rule, w.layer, aₖ, Rₖ₊₁)

 # generate input for conv layers
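
The test wrapper now has to forward the whole modify/reset cycle instead of a single `modify_layer` call. A minimal sketch of that cycle, assuming the post-commit internals named in the diff above (the `Dense` layer and `GammaRule` are placeholders for illustration):

    using Flux
    using ExplainableAI
    import ExplainableAI: modify_layer!, get_layer_resetter

    layer = Dense(4, 2, relu)   # stand-in layer
    rule  = GammaRule()

    reset! = get_layer_resetter(rule, layer)  # closure that restores the original parameters
    modify_layer!(rule, layer)                # mutates layer.weight and layer.bias in place
    # ... run the modified forward pass / pullback here ...
    reset!()                                  # undo the in-place modification

For rules that leave the layer untouched (`ZeroRule`, `EpsilonRule`), the benchmark overrides the resetter with `Returns(nothing)`, a no-op.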

docs/literate/advanced_lrp.jl

Lines changed: 36 additions & 34 deletions
@@ -48,38 +48,39 @@ heatmap(input, analyzer)
 # The rule has to be of type `AbstractLRPRule`.
 struct MyGammaRule <: AbstractLRPRule end

-# It is then possible to dispatch on the utility functions [`modify_params`](@ref) and [`modify_denominator`](@ref)
-# with the rule type `MyCustomLRPRule` to define custom rules without writing any boilerplate code.
+# It is then possible to dispatch on the utility functions [`modify_input`](@ref),
+# [`modify_param!`](@ref) and [`modify_denominator`](@ref) with the rule type
+# `MyCustomLRPRule` to define custom rules without writing any boilerplate code.
 # To extend internal functions, import them explicitly:
-import ExplainableAI: modify_params
+import ExplainableAI: modify_param!

-function modify_params(::MyGammaRule, W, b)
-    ρW = W + 0.25 * relu.(W)
-    ρb = b + 0.25 * relu.(b)
-    return ρW, ρb
+function modify_param!(::MyGammaRule, param)
+    param .+= 0.25 * relu.(param)
+    return nothing
 end

 # We can directly use this rule to make an analyzer!
 analyzer = LRP(model, MyGammaRule())
 heatmap(input, analyzer)

-# We just implemented our own version of the ``γ``-rule in 7 lines of code!
+# We just implemented our own version of the ``γ``-rule in 4 lines of code!
 # The outputs match perfectly:
 analyzer = LRP(model, GammaRule())
 heatmap(input, analyzer)

-# If the layer doesn't use weights and biases `W` and `b`, ExplainableAI provides a
-# lower-level variant of [`modify_params`](@ref) called [`modify_layer`](@ref).
-# This function is expected to take a layer and return a new, modified layer.
+# If the layer doesn't use weights `layer.weight` and biases `layer.bias`,
+# ExplainableAI provides a lower-level variant of [`modify_param!`](@ref)
+# called [`modify_layer!`](@ref). This function is expected to take a layer
+# and return a new, modified layer.

-#md # !!! warning "Using modify_layer"
+#md # !!! warning "Using modify_layer!"
 #md #
-#md #     Use of the function `modify_layer` will overwrite functionality of `modify_params`
+#md #     Use of the function `modify_layer!` will overwrite functionality of `modify_param!`
 #md #     for the implemented combination of rule and layer types.
-#md #     This is due to the fact that internally, `modify_params` is called by the default
-#md #     implementation of `modify_layer`.
+#md #     This is due to the fact that internally, `modify_param!` is called by the default
+#md #     implementation of `modify_layer!`.
 #md #
-#md #     Therefore it is recommended to only extend `modify_layer` for a specific rule
+#md #     Therefore it is recommended to only extend `modify_layer!` for a specific rule
 #md #     and a specific layer type.

 # ## Custom layers and activation functions

@@ -202,7 +203,7 @@ analyzer = LRPZero(model)
 # The correct rule is applied via [multiple dispatch](https://www.youtube.com/watch?v=kc9HwsxE1OY)
 # on the types of the arguments `rule` and `layer`.
 # The relevance `Rₖ` is then computed based on the input activation `aₖ` and the output relevance `Rₖ₊₁`.
-# Multiple dispatch is also used to dispatch `modify_params` and `modify_denominator` on the rule and layer type.
+# Multiple dispatch is also used to dispatch `modify_param!` and `modify_denominator` on the rule and layer type.
 #
 # Calling `analyze` on a LRP-model applies a forward-pass of the model, keeping track of
 # the activations `aₖ` for each layer `k`.

@@ -215,7 +216,7 @@ analyzer = LRPZero(model)
 # R_{j}=\sum_{k} \frac{a_{j} \cdot \rho\left(w_{j k}\right)}{\epsilon+\sum_{0, j} a_{j} \cdot \rho\left(w_{j k}\right)} R_{k}
 # ```
 #
-# where ``\rho`` is a function that modifies parameters – what we have so far called `modify_params`.
+# where ``\rho`` is a function that modifies parameters – what we call `modify_param!`.
 #
 # The computation of this propagation rule can be decomposed into four steps:
 # ```math

@@ -243,21 +244,20 @@ analyzer = LRPZero(model)

 # ### AD fallback
 # The default LRP fallback for unknown layers uses AD via [Zygote](https://github.com/FluxML/Zygote.jl).
-# For `lrp!`, we end up with something that looks very similar to the previous four step computation:
+# For `lrp!`, we implement the previous four step computation using `Zygote.pullback` to
+# compute ``c`` from the previous equation as a VJP, pulling back ``s_{k}=R_{k}/z_{k}``:
 # ```julia
 # function lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)
-#     layerᵨ = modify_layer(rule, layer)
-#     c = gradient(aₖ) do a
-#         z = layerᵨ(a)
-#         s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
-#         z ⋅ s
-#     end |> only
-#     Rₖ .= aₖ .* c
+#     reset! = get_layer_resetter(layer)
+#     modify_layer!(rule, layer)
+#     ãₖ₊₁, pullback = Zygote.pullback(layer, modify_input(rule, aₖ))
+#     Rₖ .= aₖ .* only(pullback(Rₖ₊₁ ./ modify_denominator(rule, ãₖ₊₁)))
+#     reset!()
 # end
 # ```
 #
-# You can see how `modify_layer` and `modify_denominator` dispatch on the rule and layer type.
-# This is how we implemented our own `MyGammaRule`.
+# You can see how `modify_layer!`, `modify_input` and `modify_denominator` dispatch on the
+# rule and layer type. This is how we implemented our own `MyGammaRule`.
 # Unknown layers that are registered in the `LRP_CONFIG` use this exact function.

 # ### Specialized implementations

@@ -267,7 +267,7 @@ analyzer = LRPZero(model)
 # Reshaping layers don't affect attributions. We can therefore avoid the computational
 # overhead of AD by writing a specialized implementation that simply reshapes back:
 # ```julia
-# function lrp!(Rₖ, ::AbstractLRPRule, ::ReshapingLayer, aₖ, Rₖ₊₁)
+# function lrp!(Rₖ, rule, ::ReshapingLayer, aₖ, Rₖ₊₁)
 #     Rₖ .= reshape(Rₖ₊₁, size(aₖ))
 # end
 # ```

@@ -276,14 +276,16 @@ analyzer = LRPZero(model)
 #
 # We can even implement the generic rule as a specialized implementation for `Dense` layers:
 # ```julia
-# function lrp!(Rₖ, rule::AbstractLRPRule, layer::Dense, aₖ, Rₖ₊₁)
-#     ρW, ρb = modify_params(rule, get_params(layer)...)
-#     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
-#     @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k] # Tullio ≈ fast einsum
+# function lrp!(Rₖ, rule, layer::Dense, aₖ, Rₖ₊₁)
+#     reset! = get_layer_resetter(rule, layer)
+#     modify_layer!(rule, layer)
+#     ãₖ₊₁ = modify_denominator(rule, layer(modify_input(rule, aₖ)))
+#     @tullio Rₖ[j, b] = aₖ[j, b] * layer.weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b] # Tullio ≈ fast einsum
+#     reset!()
 # end
 # ```
 #
-# For maximum low-level control beyond `modify_layer`, `modify_params` and `modify_denominator`,
+# For maximum low-level control beyond `modify_layer!`, `modify_param!` and `modify_denominator`,
 # you can also implement your own `lrp!` function and dispatch
 # on individual rule types `MyRule` and layer types `MyLayer`:
 # ```julia
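
The switch from `gradient` to `Zygote.pullback` in the AD fallback computes the same quantity as before: for a purely linear layer, pulling back `s` yields `Wᵀs`, which is the `c` of step four in the derivation. A small self-contained check of that claim (sizes and values are arbitrary; only Flux and Zygote are assumed, no ExplainableAI internals):

    using Flux, Zygote

    W = randn(3, 4)
    layer = Dense(W, zeros(3), identity)   # linear layer: zero bias, no nonlinearity

    aₖ   = randn(4)                        # input activation
    Rₖ₊₁ = randn(3)                        # relevance from the layer above

    zₖ₊₁, pullback = Zygote.pullback(layer, aₖ)
    s = Rₖ₊₁ ./ zₖ₊₁        # plain division here; the real rules stabilize this via modify_denominator
    c = only(pullback(s))   # VJP of s through the layer
    c ≈ W' * s              # true: the pullback of a linear map is its transpose
    Rₖ = aₖ .* c            # step four of the propagation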

docs/src/api.md

Lines changed: 3 additions & 2 deletions
@@ -32,9 +32,10 @@ ZBoxRule
 ## Custom rules
 These utilities can be used to define custom rules without writing boilerplate code:
 ```@docs
+modify_input
 modify_denominator
-modify_params
-modify_layer
+modify_param!
+modify_layer!
 LRP_CONFIG.supports_layer
 LRP_CONFIG.supports_activation
 ```

src/ExplainableAI.jl

Lines changed: 3 additions & 1 deletion
@@ -16,6 +16,7 @@ using ColorSchemes
 using Markdown
 using PrettyTables

+include("compat.jl")
 include("neuron_selection.jl")
 include("analyze_api.jl")
 include("types.jl")

@@ -42,7 +43,8 @@ export LRP, LRPZero, LRPEpsilon, LRPGamma
 export AbstractLRPRule
 export LRP_CONFIG
 export ZeroRule, EpsilonRule, GammaRule, ZBoxRule
-export modify_denominator, modify_params, modify_layer
+export modify_input, modify_denominator
+export modify_param!, modify_layer!
 export check_model

 # heatmapping

src/compat.jl

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# https://github.com/JuliaLang/julia/pull/39794
+if VERSION < v"1.7.0-DEV.793"
+    export Returns
+
+    struct Returns{V} <: Function
+        value::V
+        Returns{V}(value) where {V} = new{V}(value)
+        Returns(value) = new{Core.Typeof(value)}(value)
+    end
+
+    (obj::Returns)(args...; kw...) = obj.value
+    function Base.show(io::IO, obj::Returns)
+        show(io, typeof(obj))
+        print(io, "(")
+        show(io, obj.value)
+        return print(io, ")")
+    end
+end
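
This shim backfills `Base.Returns` (added in Julia 1.7, see the linked PR) on older Julia versions: a callable that ignores all arguments and always returns the wrapped value. It is what lets `get_layer_resetter` hand back a no-op for rules that never modify the layer. Behaviour in a nutshell:

    reset! = Returns(nothing)
    reset!()            # nothing
    reset!(1, 2; a=3)   # still nothing -- positional and keyword arguments are ignored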

src/flux.jl

Lines changed: 4 additions & 17 deletions
@@ -59,22 +59,9 @@ function strip_softmax(model::Chain)
     end
     return model
 end
-strip_softmax(l::Union{Dense,Conv}) = set_params(l, l.weight, l.bias, identity)
-
-# helper function to work around `bias=false` (Flux v0.13) and `bias=Flux.Zeros` (v0.12)
-function get_params(layer)
-    W = layer.weight
-    b = layer.bias
-    if b == false || typeof(b) <: Flux.Zeros
-        b = zeros(eltype(W), size(W, 1))
-    end
-    return W, b
+strip_softmax(l::Dense) = Dense(l.weight, l.bias, identity)
+function strip_softmax(l::Conv)
+    return Conv(identity, l.weight, l.bias, l.stride, l.pad, l.dilation, l.groups)
 end

-"""
-    set_params(layer, W, b)
-
-Duplicate layer using weights W, b.
-"""
-set_params(l::Conv, W, b, σ=l.σ) = Conv(σ, W, b, l.stride, l.pad, l.dilation, l.groups)
-set_params(l::Dense, W, b, σ=l.σ) = Dense(W, b, σ)
+has_weight_and_bias(layer) = hasproperty(layer, :weight) && hasproperty(layer, :bias)
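
With `set_params` and `get_params` gone, layers are duplicated by calling the Flux constructors directly, and the bias handling is reduced to a property check. A hedged usage sketch (layer sizes are made up; this assumes `strip_softmax` is exported and removes the model's output softmax, while `has_weight_and_bias` stays internal):

    using Flux
    using ExplainableAI
    import ExplainableAI: has_weight_and_bias

    model = Chain(Dense(4, 3, relu), Dense(3, 2), softmax)
    model = strip_softmax(model)         # forward passes now return raw logits

    has_weight_and_bias(Dense(3, 2))     # true  -- the layer has :weight and :bias properties
    has_weight_and_bias(Flux.flatten)    # false -- a plain function has neither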
