# # Advanced LRP usage
# One of the design goals of ExplainabilityMethods.jl is to combine ease of use with
# **extensibility** for the purpose of research.
#
#
# This example will show you how to implement custom LRP rules and register custom layers
# and activation functions.
#
# For this purpose, we will quickly load our model from the previous section:
using ExplainabilityMethods
using Flux
using MLDatasets
using ImageCore
using BSON

model = BSON.load("../model.bson", @__MODULE__)[:model]

index = 10
x, y = MNIST.testdata(Float32, index)
input = reshape(x, 28, 28, 1, :);

# ## Custom LRP rules
# Let's define a rule that modifies the weights and biases of our layer on the forward pass.
# The rule has to be a subtype of `AbstractLRPRule`.
struct MyGammaRule <: AbstractLRPRule end

# It is then possible to dispatch on the utility functions [`modify_params`](@ref) and [`modify_denominator`](@ref)
# with our rule type `MyGammaRule` to define custom rules without writing any boilerplate code.
# To extend these internal functions, import them explicitly:
import ExplainabilityMethods: modify_params

function modify_params(::MyGammaRule, W, b)
    ρW = W + 0.25 * relu.(W)
    ρb = b + 0.25 * relu.(b)
    return ρW, ρb
end
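#
# The same mechanism applies to [`modify_denominator`](@ref), which receives the denominator
# of the relevance propagation rule. As a purely illustrative sketch (the rule name
# `MyEpsilonRule` and the constant are made up for this example), an ε-style stabilizer
# could be added like this:
# ```julia
# import ExplainabilityMethods: modify_denominator
#
# struct MyEpsilonRule <: AbstractLRPRule end
#
# # Add a small constant to avoid division by zero in the relevance propagation
# modify_denominator(::MyEpsilonRule, d) = d .+ 1.0f-9
# ```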

# We can directly use this rule to make an analyzer!
analyzer = LRP(model, MyGammaRule())
heatmap(input, analyzer)

# We just implemented our own version of the ``γ``-rule in 7 lines of code!
# The output matches that of the built-in `GammaRule`:
analyzer = LRP(model, GammaRule())
heatmap(input, analyzer)
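#
# If you want to verify this programmatically, a quick sketch (assuming both rules perform
# the exact same floating-point computation) is to compare the two heatmaps directly:
# ```julia
# heatmap(input, LRP(model, MyGammaRule())) == heatmap(input, LRP(model, GammaRule()))
# ```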

# If the layer doesn't use weights and biases `W` and `b`, ExplainabilityMethods provides a
# lower-level variant of [`modify_params`](@ref) called [`modify_layer`](@ref).
# This function is expected to take a layer and return a new, modified layer.

#md # !!! warning "Using `modify_layer`"
#md #
#md #     Extending `modify_layer` overrides the functionality of `modify_params`
#md #     for the implemented combination of rule and layer types.
#md #     This is because internally, `modify_params` is called by the default
#md #     implementation of `modify_layer`.
#md #
#md #     It is therefore recommended to only extend `modify_layer` for a specific rule
#md #     and a specific layer type.
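#
# As a sketch of what such an extension could look like (the rule `MyFlippingRule` is made up,
# and we assume Flux's `Dense` layer stores its parameters in the fields `weight`, `bias` and `σ`):
# ```julia
# import ExplainabilityMethods: modify_layer
#
# struct MyFlippingRule <: AbstractLRPRule end
#
# # Illustrative only: return a new Dense layer with negated weights
# function modify_layer(::MyFlippingRule, layer::Dense)
#     return Dense(-layer.weight, layer.bias, layer.σ)
# end
# ```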

# ## Custom layers and activation functions
# ### Model checks for humans
# Good model checks and presets should allow novice users to apply XAI methods
# in a "plug & play" manner according to best practices.
#
# Let's say we define a layer that doubles its input:
struct MyDoublingLayer end
(::MyDoublingLayer)(x) = 2 * x

mylayer = MyDoublingLayer()
mylayer([1, 2, 3])

# Let's append this layer to our model:
model = Chain(model..., MyDoublingLayer())

# Creating an LRP analyzer, e.g. `LRPZero(model)`, will throw an `ArgumentError`
# and print a summary of the model check in the REPL:
# ```julia-repl
# ┌───┬───────────────────────┬─────────────────┬────────────┬────────────────┐
# │   │                 Layer │ Layer supported │ Activation │ Act. supported │
# ├───┼───────────────────────┼─────────────────┼────────────┼────────────────┤
# │ 1 │               flatten │            true │          — │           true │
# │ 2 │ Dense(784, 100, relu) │            true │       relu │           true │
# │ 3 │        Dense(100, 10) │            true │   identity │           true │
# │ 4 │     MyDoublingLayer() │           false │          — │           true │
# └───┴───────────────────────┴─────────────────┴────────────┴────────────────┘
# Layers failed model check
# ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡
#
# Found unknown layers MyDoublingLayer() that are not supported by ExplainabilityMethods' LRP implementation yet.
#
# If you think the missing layer should be supported by default, please submit an issue (https://github.com/adrhill/ExplainabilityMethods.jl/issues).
#
# These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument skip_checks=true.
#
# [...]
# ```

# LRP should only be used on "Deep ReLU" networks and ExplainabilityMethods doesn't
# recognize `MyDoublingLayer` as a compatible layer.
# By default, it will therefore throw an error and print a model check summary
# instead of returning an incorrect explanation.
#
# However, if we know `MyDoublingLayer` is compatible with "Deep ReLU" networks,
# we can register it to tell ExplainabilityMethods that it is ok to use.
# This will be shown in the following section.

#md # !!! warning "Skipping model checks"
#md #
#md #     All model checks can be skipped at the user's own risk by setting the LRP-analyzer
#md #     keyword argument `skip_checks=true`.
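#
# For example (a sketch; we assume the keyword is simply passed to the analyzer constructor):
# ```julia
# analyzer = LRPZero(model; skip_checks=true)
# ```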

# ### Registering custom layers
# The model check will no longer error once we register our custom layer type
# `MyDoublingLayer` as "supported" by ExplainabilityMethods.
#
# This is done using the function [`LRP_CONFIG.supports_layer`](@ref), which should be set to return `true`:
LRP_CONFIG.supports_layer(::MyDoublingLayer) = true

# Now we can create and run an analyzer without getting an error:
analyzer = LRPZero(model)
heatmap(input, analyzer)

#md # !!! note "Registering functions"
#md #
#md #     Flux `Chain`s can also contain functions, e.g. `flatten`.
#md #     A function `myfunction` used as a layer can be registered as
#md #     ```julia
#md #     LRP_CONFIG.supports_layer(::typeof(myfunction)) = true
#md #     ```

# ### Registering activation functions
# The mechanism for registering custom activation functions is analogous to that of custom layers:
myrelu(x) = max.(0, x)
model = Chain(flatten, Dense(784, 100, myrelu), Dense(100, 10))

# Once again, creating an LRP analyzer for this model will throw an `ArgumentError`
# and display the following model check summary:
# ```julia-repl
# julia> analyzer = LRPZero(model)
# ┌───┬─────────────────────────┬─────────────────┬────────────┬────────────────┐
# │   │                   Layer │ Layer supported │ Activation │ Act. supported │
# ├───┼─────────────────────────┼─────────────────┼────────────┼────────────────┤
# │ 1 │                 flatten │            true │          — │           true │
# │ 2 │ Dense(784, 100, myrelu) │            true │     myrelu │          false │
# │ 3 │          Dense(100, 10) │            true │   identity │           true │
# └───┴─────────────────────────┴─────────────────┴────────────┴────────────────┘
# Activations failed model check
# ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡
#
# Found layers with unknown or unsupported activation functions myrelu. LRP assumes that the model is a "deep rectifier network" that only contains ReLU-like activation functions.
#
# If you think the missing activation function should be supported by default, please submit an issue (https://github.com/adrhill/ExplainabilityMethods.jl/issues).
#
# These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument skip_checks=true.
#
# [...]
# ```

# Registration works by defining the function [`LRP_CONFIG.supports_activation`](@ref) to return `true`:
LRP_CONFIG.supports_activation(::typeof(myrelu)) = true

# Now the analyzer can be created without error:
analyzer = LRPZero(model)

# ## How it works internally
# Internally, ExplainabilityMethods dispatches to low-level functions of the form
# ```julia
# lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
# ```
# These functions dispatch on rule and layer type and modify the pre-allocated array `Rₖ`
# in place, based on the inputs `aₖ` and `Rₖ₊₁`.
#
# The default LRP fallback for unknown layers uses automatic differentiation (AD) via Zygote:
# ```julia
# function lrp!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
#     layerᵨ = modify_layer(rule, layer)
#     c = gradient(aₖ) do a
#         z = layerᵨ(a)
#         s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
#         z ⋅ s
#     end |> only
#     Rₖ .= aₖ .* c
# end
# ```
#
# Here you can clearly see how this AD fallback dispatches on `modify_layer` and `modify_denominator`
# based on the rule and layer type. This is how we implemented our own `MyGammaRule`!
# Unknown layers that are registered in the `LRP_CONFIG` use this exact function.
#
# We can also implement versions of `lrp!` that are specialized for specific layer types.
# For example, reshaping layers don't affect attributions, therefore no AD is required.
# ExplainabilityMethods implements:
# ```julia
# function lrp!(rule, ::ReshapingLayer, Rₖ, aₖ, Rₖ₊₁)
#     Rₖ .= reshape(Rₖ₊₁, size(aₖ))
# end
# ```
#
# Even Dense layers have a specialized implementation:
# ```julia
# function lrp!(rule, layer::Dense, Rₖ, aₖ, Rₖ₊₁)
#     ρW, ρb = modify_params(rule, get_params(layer)...)
#     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
#     @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k] # Tullio = fast einsum
# end
# ```
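#
# Written out as a formula, with ``j`` indexing the neurons of layer ``k`` and ``i`` indexing
# the neurons of layer ``k+1``, the Tullio expression above computes
# ```math
# R_k[j] = \sum_i \frac{a_k[j] \, \rho(W)[i, j]}{\tilde{a}_{k+1}[i]} \, R_{k+1}[i] ,
# ```
# where ``\rho(W)`` and ``\rho(b)`` are the modified parameters and ``\tilde{a}_{k+1}``
# is the stabilized denominator from the code above.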
# Just like in the LRP papers!
#
# For maximum low-level control, you can also implement your own `lrp!` function
# and dispatch on individual rule types `MyRule` and layer types `MyLayer`:
# ```julia
# function lrp!(rule::MyRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁)
#     Rₖ .= ...
# end
# ```
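#
# As a final sketch (purely illustrative, reusing the `MyDoublingLayer` from above): since an
# element-wise scaling layer multiplies every contribution and the corresponding denominator
# by the same factor, relevance can simply be passed through unchanged:
# ```julia
# function lrp!(rule, ::MyDoublingLayer, Rₖ, aₖ, Rₖ₊₁)
#     Rₖ .= Rₖ₊₁
# end
# ```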