Skip to content

Commit a403c45

Browse files
authored
Faster LRP rules on Dense layers (#31)
* Add faster AD-less LRP implementation for Dense layer using Tullio.jl
1 parent 0c4c786 commit a403c45

File tree

4 files changed

+28
-8
lines changed

4 files changed

+28
-8
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1212
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1313
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
14+
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1415
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1516

1617
[compat]
1718
ColorSchemes = "3"
1819
Flux = "0.12"
1920
ImageCore = "0.8, 0.9"
2021
PrettyTables = "1"
22+
Tullio = "0.3"
2123
Zygote = "0.6"
2224
julia = "1.6"
2325

src/ExplainabilityMethods.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Zygote
66
using ColorSchemes
77
using ImageCore
88
using Base.Iterators
9+
using Tullio
910

1011
using Markdown
1112
using PrettyTables

src/lrp_rules.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,27 @@
44
# can be implemented by dispatching on the functions `modify_params` & `modify_denominator`,
55
# which make use of the generalized LRP implementation shown in [1].
66
#
7-
# If the relevance propagation falls outside of this scheme, a custom function
7+
# If the relevance propagation falls outside of this scheme, custom functions
88
# ```julia
99
# (::MyLRPRule)(layer, aₖ, Rₖ₊₁) = ...
10+
# (::MyLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
11+
# (::AbstractLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
1012
# ```
11-
# can be implemented. This is used for the ZBoxRule.
13+
# that return `Rₖ` can be implemented.
14+
# This is used for the ZBoxRule and for faster computations on common layers.
1215
#
1316
# References:
1417
# [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
15-
# [2] W. Samek et al., Explaining Deep Neural Networks and Beyond:
16-
# A Review of Methods and Applications
18+
# [2] W. Samek et al., Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications
1719

1820
abstract type AbstractLRPRule end
1921

2022
# This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
2123
# It can be extended for new rules via `modify_denominator` and `modify_params`.
2224
# Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
23-
function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
25+
(rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁) = lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
26+
27+
function lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
2428
layerᵨ = _modify_layer(rule, layer)
2529
function fwpass(a)
2630
z = layerᵨ(a)
@@ -30,7 +34,16 @@ function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
3034
return aₖ .* gradient(fwpass, aₖ)[1] # Rₖ
3135
end
3236

33-
# Special cases are dispatched on layer type:
37+
# For linear layer types such as Dense layers, using autodiff is overkill.
38+
(rule::AbstractLRPRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_dense(rule, layer, aₖ, Rₖ₊₁)
39+
40+
function lrp_dense(rule, l, aₖ, Rₖ₊₁)
41+
ρW, ρb = modify_params(rule, get_params(l)...)
42+
ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
43+
return @tullio Rₖ[j] := aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
44+
end
45+
46+
# Other special cases that are dispatched on layer type:
3447
(::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
3548
(::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
3649

@@ -104,7 +117,10 @@ Commonly used on the first layer for pixel input.
104117
struct ZBoxRule <: AbstractLRPRule end
105118

106119
# The ZBoxRule requires its own implementation of relevance propagation.
107-
function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
120+
(rule::ZBoxRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
121+
(rule::ZBoxRule)(layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
122+
123+
function lrp_zbox(layer, aₖ, Rₖ₊₁)
108124
W, b = get_params(layer)
109125
l, h = fill.(extrema(aₖ), (size(aₖ),))
110126

src/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
44
Replace zero terms of a matrix `d` with `eps`.
55
"""
6-
function stabilize_denom(d; eps=1.0f-9)
6+
stabilize_denom(d::Real; eps=1.0f-9) = ifelse(d 0, d + sign(d) * eps, d)
7+
function stabilize_denom(d::AbstractArray; eps=1.0f-9)
78
return d + ((d .≈ 0) + sign.(d)) * eps
89
end
910

0 commit comments

Comments
 (0)