
Commit b4fb888

Make LRP rules type inferable (#33)
* introduce `lrp` for rule calls
* help type inference and specialization in rule calls
* test type stability with `@inferred`
* add precompilation of LRP rules
1 parent f120801 commit b4fb888
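
In short, rules are no longer applied as callable structs; they are passed to a new `lrp` function, which gives the compiler a single generic function to specialize and infer through. A minimal before/after sketch with placeholder layer and relevance values (not taken from this diff):

using Flux, ExplainabilityMethods
using Test: @inferred

W, b = rand(Float32, 3, 2), rand(Float32, 3)
layer = Dense(W, b, relu)      # any supported layer
aₖ = rand(Float32, 2)          # the layer's input activation
Rₖ₊₁ = layer(aₖ)               # relevance arriving from the layer above

# before this commit: rules were callable structs
# Rₖ = ZeroRule()(layer, aₖ, Rₖ₊₁)

# after this commit: rules are dispatched through `lrp`
Rₖ = lrp(ZeroRule(), layer, aₖ, Rₖ₊₁)

# type stability can now be asserted directly
@inferred lrp(ZeroRule(), layer, aₖ, Rₖ₊₁)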

7 files changed: +94 -44 lines changed


benchmark/benchmarks.jl

Lines changed: 3 additions & 5 deletions
@@ -1,7 +1,7 @@
 using BenchmarkTools
 using Flux
 using ExplainabilityMethods
-import ExplainabilityMethods: _modify_layer
+import ExplainabilityMethods: _modify_layer, lrp

 on_CI = haskey(ENV, "GITHUB_ACTIONS")

@@ -44,7 +44,7 @@ struct TestWrapper{T}
 end
 (w::TestWrapper)(x) = w.layer(x)
 _modify_layer(r::AbstractLRPRule, w::TestWrapper) = _modify_layer(r, w.layer)
-(rule::ZBoxRule)(w::TestWrapper, aₖ, Rₖ₊₁) = rule(w.layer, aₖ, Rₖ₊₁)
+lrp(rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp(rule, w.layer, aₖ, Rₖ₊₁)

 # generate input for conv layers
 insize = (64, 64, 3, 1)
@@ -66,15 +66,13 @@ rules = Dict(
     "ZBoxRule" => ZBoxRule(),
 )

-test_rule(rule, layer, aₖ, Rₖ₊₁) = rule(layer, aₖ, Rₖ₊₁) # for use with @benchmarkable macro
-
 SUITE["Layer"] = BenchmarkGroup([k for k in keys(layers)])
 for (layername, (layer, aₖ)) in layers
     SUITE["Layer"][layername] = BenchmarkGroup([k for k in keys(rules)])

     Rₖ₊₁ = layer(aₖ)
     for (rulename, rule) in rules
-        SUITE["Layer"][layername][rulename] = @benchmarkable test_rule(
+        SUITE["Layer"][layername][rulename] = @benchmarkable lrp(
             $(rule), $(layer), $(aₖ), $(Rₖ₊₁)
         )
     end
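
With the `test_rule` helper gone, `lrp` is benchmarked directly. For reference, a typical way to run this suite locally — standard BenchmarkTools usage, assuming the call is made from the repository root; none of this is part of the diff:

using BenchmarkTools

include("benchmark/benchmarks.jl")   # defines SUITE, layers, and rules

tune!(SUITE)                # pick sample/evaluation counts per benchmark
results = run(SUITE)        # BenchmarkGroup of trial results
show(results["Layer"])      # inspect the per-layer, per-rule timings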

src/ExplainabilityMethods.jl

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,8 @@ include("lrp_checks.jl")
 include("lrp_rules.jl")
 include("lrp.jl")
 include("heatmap.jl")
+include("precompile.jl")
+_precompile_()

 export analyze

@@ -32,7 +34,7 @@ export LRP, LRPZero, LRPEpsilon, LRPGamma
 export AbstractLRPRule
 export LRP_CONFIG
 export ZeroRule, EpsilonRule, GammaRule, ZBoxRule
-export modify_params, modify_denominator
+export lrp, modify_params, modify_denominator
 export check_model

 # heatmapping

src/lrp.jl

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ function (analyzer::LRP)(input, ns::AbstractNeuronSelector; layerwise_relevances

     # Backward pass through layers, applying LRP rules
     for (i, rule) in Iterators.reverse(enumerate(analyzer.rules))
-        rels[i] .= rule(layers[i], acts[i], rels[i + 1]) # Rₖ = rule(layer, aₖ, Rₖ₊₁)
+        rels[i] .= lrp(rule, layers[i], acts[i], rels[i + 1])
     end

     if layerwise_relevances

src/lrp_rules.jl

Lines changed: 30 additions & 24 deletions
@@ -22,30 +22,38 @@ abstract type AbstractLRPRule end
 # This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
 # It can be extended for new rules via `modify_denominator` and `modify_params`.
 # Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
-(rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁) = lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+function lrp(rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
+    return lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+end

-function lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+function lrp_autodiff(
+    rule::R, layer::L, aₖ::T1, Rₖ₊₁::T2
+) where {R<:AbstractLRPRule,L,T1,T2}
     layerᵨ = _modify_layer(rule, layer)
-    function fwpass(a)
-        z = layerᵨ(a)
-        s = Zygote.dropgrad(Rₖ₊₁ ./ modify_denominator(rule, z))
-        return z ⋅ s
-    end
-    return aₖ .* gradient(fwpass, aₖ)[1] # Rₖ
+    c::T1 = only(
+        gradient(aₖ) do a
+            z::T2 = layerᵨ(a)
+            s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
+            z ⋅ s
+        end,
+    )
+    return aₖ .* c # Rₖ
 end

 # For linear layer types such as Dense layers, using autodiff is overkill.
-(rule::AbstractLRPRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_dense(rule, layer, aₖ, Rₖ₊₁)
+function lrp(rule::R, layer::Dense, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
+    return lrp_dense(rule, layer, aₖ, Rₖ₊₁)
+end

-function lrp_dense(rule, l, aₖ, Rₖ₊₁)
+function lrp_dense(rule::R, l, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
     ρW, ρb = modify_params(rule, get_params(l)...)
     ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
     return @tullio Rₖ[j] := aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
 end

 # Other special cases that are dispatched on layer type:
-(::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
-(::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
+lrp(::AbstractLRPRule, ::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
+lrp(::AbstractLRPRule, ::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))

 # To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
 # If this isn't done, the following fallbacks are used by default:
@@ -65,7 +73,7 @@ modify_denominator(::AbstractLRPRule, d) = stabilize_denom(d; eps=1.0f-9) # gene

 # This helper function applies `modify_params`:
 _modify_layer(::AbstractLRPRule, layer) = layer # skip layers without modify_params
-function _modify_layer(rule::AbstractLRPRule, layer::Union{Dense,Conv})
+function _modify_layer(rule::R, layer::L) where {R<:AbstractLRPRule,L<:Union{Dense,Conv}}
     return set_params(layer, modify_params(rule, get_params(layer)...)...)
 end

@@ -117,26 +125,24 @@ Commonly used on the first layer for pixel input.
 struct ZBoxRule <: AbstractLRPRule end

 # The ZBoxRule requires its own implementation of relevance propagation.
-(rule::ZBoxRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
-(rule::ZBoxRule)(layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+lrp(::ZBoxRule, layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+lrp(::ZBoxRule, layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)

-function lrp_zbox(layer, aₖ, Rₖ₊₁)
+function lrp_zbox(layer::L, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
     W, b = get_params(layer)
     l, h = fill.(extrema(aₖ), (size(aₖ),))

     layer⁺ = set_params(layer, max.(0, W), max.(0, b)) # W⁺, b⁺
     layer⁻ = set_params(layer, min.(0, W), min.(0, b)) # W⁻, b⁻

-    # Forward pass
-    function fwpass(a, l, h)
-        f = layer(a)
-        f⁺ = layer⁺(l)
-        f⁻ = layer⁻(h)
+    c::T1, cₗ::T1, cₕ::T1 = gradient(aₖ, l, h) do a, l, h
+        f::T2 = layer(a)
+        f⁺::T2 = layer⁺(l)
+        f⁻::T2 = layer⁻(h)

         z = f - f⁺ - f⁻
-        s = Zygote.dropgrad(safedivide(Rₖ₊₁, z; eps=1e-9))
-        return z ⋅ s
+        s = Zygote.@ignore safedivide(Rₖ₊₁, z; eps=1e-9)
+        z ⋅ s
     end
-    c, cₗ, cₕ = gradient(fwpass, aₖ, l, h) # w.r.t. three inputs
     return aₖ .* c + l .* cₗ + h .* cₕ # Rₖ from backward pass
 end
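
As the comments in lrp_rules.jl describe, a new rule only needs to subtype `AbstractLRPRule` and override `modify_params` and/or `modify_denominator`; relevance propagation then goes through the generic `lrp` fallback (autodiff for arbitrary layers, the `@tullio` path for `Dense`). A minimal sketch of a hypothetical custom rule — the name `MyShiftRule` and its factor `α` are made up for illustration and not part of this commit:

using ExplainabilityMethods

# Shift weights and biases by a factor α of their positive part,
# analogous to how GammaRule scales positive weights by γ.
struct MyShiftRule <: AbstractLRPRule
    α::Float64
end

function ExplainabilityMethods.modify_params(r::MyShiftRule, W, b)
    return W .+ r.α .* max.(0, W), b .+ r.α .* max.(0, b)
end

# Dispatches through the same `lrp` entry point as the built-in rules:
# Rₖ = lrp(MyShiftRule(0.1), layer, aₖ, Rₖ₊₁)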

src/precompile.jl

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+macro warnpcfail(ex::Expr)
+    modl = __module__
+    file = __source__.file === nothing ? "?" : String(__source__.file)
+    line = __source__.line
+    quote
+        $(esc(ex)) || @warn """precompile directive $($(Expr(:quote, ex)))
+            failed. Please report an issue in $($modl) (after checking for duplicates) or remove this directive.""" _file =
+            $file _line = $line
+    end
+end
+
+function _precompile_()
+    eltypes = (Float32,)
+    ruletypes = (ZeroRule, EpsilonRule, GammaRule, ZBoxRule)
+    layertypes = (
+        Dense,
+        Conv,
+        MaxPool,
+        AdaptiveMaxPool,
+        GlobalMaxPool,
+        MeanPool,
+        AdaptiveMeanPool,
+        GlobalMeanPool,
+        DepthwiseConv,
+        ConvTranspose,
+        CrossCor,
+        Dropout,
+        AlphaDropout,
+        typeof(Flux.flatten),
+    )
+
+    for R in ruletypes
+        for T in eltypes
+            AT = Array{T}
+            @warnpcfail precompile(modify_denominator, (R, AT))
+            @warnpcfail precompile(modify_params, (R, AT, AT))
+
+            for L in layertypes
+                @warnpcfail precompile(_modify_layer, (R, L))
+                @warnpcfail precompile(lrp, (R, L, AT, AT))
+            end
+        end
+    end
+end
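
The `@warnpcfail` macro follows the pattern suggested in SnoopCompile's documentation: `precompile(f, argtypes)` returns `false` rather than throwing when no method can be compiled for the given signature, so the macro turns that silent failure into a warning with file and line information. A rough sketch of what a single directive boils down to, with argument types matching those in `_precompile_`:

# `precompile` returns a Bool: true if the method instance was compiled.
ok = precompile(lrp, (ZeroRule, Dense, Array{Float32}, Array{Float32}))
ok || @warn "precompile directive failed"   # what @warnpcfail automates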

test/test_neuron_selection.jl

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 using ExplainabilityMethods: MaxActivationNS, IndexNS

 A = [-2.1694243, 2.4023275, 0.99464744, -0.1514646, 1.0307171]
-ns1 = MaxActivationNS()
-ns2 = IndexNS(4)
+ns1 = @inferred MaxActivationNS()
+ns2 = @inferred IndexNS(4)

 @test ns1(A) == 2
 @test ns2(A) == 4
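
`@inferred` comes from the Test standard library: it returns the value of the call unchanged but throws an error if the type the compiler infers for the call does not match the type of the value actually returned, which is what makes it a type-stability assertion. A small standalone illustration, unrelated to the package code:

using Test

stable(x::Int) = x + 1           # always returns an Int
unstable(x) = x > 0 ? 1 : 1.0    # inferred as Union{Int64, Float64}

y = @inferred stable(1)          # passes and returns 2
@test y == 2
# @inferred unstable(1)          # would throw: inferred Union{Int64, Float64}
                                 # does not match the actual return type Int64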

test/test_rules.jl

Lines changed: 11 additions & 11 deletions
@@ -1,6 +1,6 @@
 using ExplainabilityMethods
 using ExplainabilityMethods: modify_params
-import ExplainabilityMethods: _modify_layer
+import ExplainabilityMethods: _modify_layer, lrp
 using Flux
 using LinearAlgebra
 using ReferenceTests
@@ -25,7 +25,7 @@ const RULES = Dict(
     Rₖ = [17 / 90, 316 / 675] # expected output

     layer = Dense(W, b, relu)
-    @test rule(layer, aₖ, Rₖ₊₁) ≈ Rₖ
+    @test lrp(rule, layer, aₖ, Rₖ₊₁) ≈ Rₖ

     ## Pooling layer
     Rₖ₊₁ = Float32.([1 2; 3 4]//30)
@@ -38,7 +38,7 @@ const RULES = Dict(
     Rₖ = reshape(repeat(Rₖ, 1, 3), 3, 3, 3, 1)

     layer = MaxPool((2, 2); stride=(1, 1))
-    @test rule(layer, aₖ, Rₖ₊₁) ≈ Rₖ
+    @test lrp(rule, layer, aₖ, Rₖ₊₁) ≈ Rₖ
 end

 # Fixed pseudo-random numbers
@@ -48,7 +48,7 @@ pseudorandn(dims...) = randn(MersenneTwister(123), T, dims...)
 ## Test individual rules
 @testset "modify_params" begin
     W, b = [1.0 -1.0; 2.0 0.0], [-1.0, 1.0]
-    ρW, ρb = modify_params(GammaRule(; γ=0.42), W, b)
+    ρW, ρb = @inferred modify_params(GammaRule(; γ=0.42), W, b)
     @test ρW ≈ [1.42 -1.0; 2.84 0.0]
     @test ρb ≈ [-1.0, 1.42]
 end
@@ -69,7 +69,7 @@ layers = Dict(
 for (layername, layer) in layers
     @testset "$layername" begin
         Rₖ₊₁ = layer(aₖ)
-        Rₖ = rule(layer, aₖ, Rₖ₊₁)
+        Rₖ = @inferred lrp(rule, layer, aₖ, Rₖ₊₁)

         @test typeof(Rₖ) == typeof(aₖ)
         @test size(Rₖ) == size(aₖ)
@@ -110,8 +110,8 @@ equalpairs = Dict( # these pairs of layers are all equal
     l1, l2 = layers
     Rₖ₊₁ = l1(aₖ)
     @test Rₖ₊₁ == l2(aₖ)
-    Rₖ = rule(l1, aₖ, Rₖ₊₁)
-    @test Rₖ == rule(l2, aₖ, Rₖ₊₁)
+    Rₖ = @inferred lrp(rule, l1, aₖ, Rₖ₊₁)
+    @test Rₖ == lrp(rule, l2, aₖ, Rₖ₊₁)

     @test typeof(Rₖ) == typeof(aₖ)
     @test size(Rₖ) == size(aₖ)
@@ -143,7 +143,7 @@ layers = Dict(
 for (layername, layer) in layers
     @testset "$layername" begin
         Rₖ₊₁ = layer(aₖ)
-        Rₖ = rule(layer, aₖ, Rₖ₊₁)
+        Rₖ = @inferred lrp(rule, layer, aₖ, Rₖ₊₁)

         @test typeof(Rₖ) == typeof(aₖ)
         @test size(Rₖ) == size(aₖ)
@@ -158,13 +158,13 @@ layers = Dict(
 end

 ## Test custom layers & default AD fallback using the ZeroRule
-## Compare with references of non-wrapped layers
+# Compare with references of non-wrapped layers
 struct TestWrapper{T}
     layer::T
 end
 (w::TestWrapper)(x) = w.layer(x)
 _modify_layer(r::AbstractLRPRule, w::TestWrapper) = _modify_layer(r, w.layer)
-(rule::ZBoxRule)(w::TestWrapper, aₖ, Rₖ₊₁) = rule(w.layer, aₖ, Rₖ₊₁)
+lrp(rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp(rule, w.layer, aₖ, Rₖ₊₁)

 layers = Dict(
     "Conv" => (Conv((3, 3), 2 => 4; init=pseudorandn), aₖ),
@@ -179,7 +179,7 @@ layers = Dict(
     @testset "$layername" begin
         wrapped_layer = TestWrapper(layer)
         Rₖ₊₁ = wrapped_layer(aₖ)
-        Rₖ = rule(wrapped_layer, aₖ, Rₖ₊₁)
+        Rₖ = @inferred lrp(rule, wrapped_layer, aₖ, Rₖ₊₁)

         @test typeof(Rₖ) == typeof(aₖ)
         @test size(Rₖ) == size(aₖ)
